Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
833cc7a2
Commit
833cc7a2
authored
May 23, 2019
by
Shucai Xiao
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into int8_miopen_call
parents
5be03483
0d796941
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
61 additions
and
43 deletions
+61
-43
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+0
-7
src/env.cpp
src/env.cpp
+8
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+7
-0
src/include/migraphx/env.hpp
src/include/migraphx/env.hpp
+9
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+1
-0
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+12
-5
src/program.cpp
src/program.cpp
+9
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+6
-8
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-1
src/tf/tf.cpp
src/tf/tf.cpp
+3
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+0
-16
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+3
-2
No files found.
src/eliminate_contiguous.cpp
View file @
833cc7a2
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if
(
ins
->
name
()
==
"reshape"
)
{
continue
;
}
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
...
...
src/env.cpp
View file @
833cc7a2
...
@@ -21,6 +21,14 @@ bool disabled(const char* name)
...
@@ -21,6 +21,14 @@ bool disabled(const char* name)
return
contains
({
"0"
,
"disable"
,
"disabled"
,
"no"
,
"false"
},
e
.
front
());
return
contains
({
"0"
,
"disable"
,
"disabled"
,
"no"
,
"false"
},
e
.
front
());
}
}
std
::
size_t
value_of
(
const
char
*
name
)
{
auto
e
=
env
(
name
);
if
(
e
.
empty
())
return
0
;
return
std
::
stoul
(
e
.
front
());
}
std
::
vector
<
std
::
string
>
env
(
const
char
*
name
)
std
::
vector
<
std
::
string
>
env
(
const
char
*
name
)
{
{
auto
p
=
std
::
getenv
(
name
);
auto
p
=
std
::
getenv
(
name
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
833cc7a2
...
@@ -103,6 +103,13 @@ struct check_shapes
...
@@ -103,6 +103,13 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
...
...
src/include/migraphx/env.hpp
View file @
833cc7a2
...
@@ -19,6 +19,8 @@ bool enabled(const char* name);
...
@@ -19,6 +19,8 @@ bool enabled(const char* name);
bool
disabled
(
const
char
*
name
);
bool
disabled
(
const
char
*
name
);
std
::
vector
<
std
::
string
>
env
(
const
char
*
name
);
std
::
vector
<
std
::
string
>
env
(
const
char
*
name
);
std
::
size_t
value_of
(
const
char
*
name
);
template
<
class
T
>
template
<
class
T
>
bool
enabled
(
T
)
bool
enabled
(
T
)
{
{
...
@@ -33,6 +35,13 @@ bool disabled(T)
...
@@ -33,6 +35,13 @@ bool disabled(T)
return
result
;
return
result
;
}
}
template
<
class
T
>
std
::
size_t
value_of
(
T
)
{
static
const
std
::
size_t
result
=
value_of
(
T
::
value
());
return
result
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/op/reshape.hpp
View file @
833cc7a2
...
@@ -29,7 +29,7 @@ struct reshape
...
@@ -29,7 +29,7 @@ struct reshape
std
::
string
name
()
const
{
return
"reshape"
;
}
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
...
...
src/include/migraphx/op/squeeze.hpp
View file @
833cc7a2
...
@@ -29,6 +29,7 @@ struct squeeze
...
@@ -29,6 +29,7 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
833cc7a2
...
@@ -29,6 +29,7 @@ struct unsqueeze
...
@@ -29,6 +29,7 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard_or_scalar
();
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/include/migraphx/raw_data.hpp
View file @
833cc7a2
...
@@ -27,6 +27,7 @@ struct raw_data : raw_data_base
...
@@ -27,6 +27,7 @@ struct raw_data : raw_data_base
template
<
class
Stream
>
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
{
if
(
not
d
.
empty
())
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
return
os
;
return
os
;
}
}
...
@@ -40,8 +41,11 @@ struct raw_data : raw_data_base
...
@@ -40,8 +41,11 @@ struct raw_data : raw_data_base
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
derived
.
empty
())
MIGRAPHX_THROW
(
"Visiting empty data!"
);
auto
&&
s
=
derived
.
get_shape
();
auto
&&
buffer
=
derived
.
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
}
}
...
@@ -55,8 +59,11 @@ struct raw_data : raw_data_base
...
@@ -55,8 +59,11 @@ struct raw_data : raw_data_base
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
void
visit
(
Visitor
v
)
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
derived
.
empty
())
MIGRAPHX_THROW
(
"Visiting empty data!"
);
auto
&&
s
=
derived
.
get_shape
();
auto
&&
buffer
=
derived
.
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
}
...
...
src/program.cpp
View file @
833cc7a2
...
@@ -437,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -437,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
#else
#else
auto
check_context
=
[](
auto
f
)
{
return
f
();
};
auto
check_context
=
[](
auto
f
)
{
return
f
();
};
#endif
#endif
if
(
enabled
(
MIGRAPHX_TRACE_EVAL
{}))
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
if
(
trace_level
>
0
)
{
{
return
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
return
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
ctx
.
finish
();
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
;
std
::
cout
<<
"Run instruction: "
;
this
->
debug_print
(
ins
);
this
->
debug_print
(
ins
);
return
check_context
(
f
);
auto
result
=
check_context
(
f
);
ctx
.
finish
();
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
)
std
::
cout
<<
"Ouput: "
<<
result
<<
std
::
endl
;
return
result
;
});
});
}
}
else
else
...
...
src/simplify_reshapes.cpp
View file @
833cc7a2
...
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
...
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"reshape"
,
"contiguous"
"contiguous"
,
"squeeze"
,
"unsqueeze"
};
};
// clang-format on
// clang-format on
return
contains
(
names
,
ins
->
name
());
return
contains
(
names
,
ins
->
name
());
...
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
...
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto
end
=
std
::
prev
(
p
.
end
());
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
continue
;
if
(
is_reshaper
(
ins
))
if
(
is_reshaper
(
ins
))
...
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
...
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
}
}
// Replace all reshapes with as_shape
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"reshape"
)
continue
;
p
.
replace_instruction
(
ins
,
op
::
as_shape
{
ins
->
get_shape
()},
ins
->
inputs
());
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
833cc7a2
...
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
propagate_constant
{},
propagate_constant
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
//
simplify_reshapes{},
simplify_reshapes
{},
dead_code_elimination
{},
dead_code_elimination
{},
lowering
{
ctx
},
lowering
{
ctx
},
eliminate_concat
{
concat_gpu_optimization
{}},
eliminate_concat
{
concat_gpu_optimization
{}},
...
...
src/tf/tf.cpp
View file @
833cc7a2
...
@@ -393,7 +393,9 @@ struct tf_parser
...
@@ -393,7 +393,9 @@ struct tf_parser
int64_t
out_channels
=
num_channels
*
multiplier
;
int64_t
out_channels
=
num_channels
*
multiplier
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
new_weights_shape
[
1
]
=
1
;
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
weights
);
// Make sure weights are contiguous before doing reshape
auto
cweights
=
prog
.
add_instruction
(
op
::
contiguous
{},
weights
);
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
cweights
);
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
}
}
...
...
test/gpu/miopen.cpp
View file @
833cc7a2
...
@@ -1364,22 +1364,6 @@ struct test_contiguous : verify_program<test_contiguous>
...
@@ -1364,22 +1364,6 @@ struct test_contiguous : verify_program<test_contiguous>
}
}
};
};
struct
test_eliminate_contiguous
:
verify_program
<
test_eliminate_contiguous
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
auto
seq
=
p
.
add_parameter
(
"seq"
,
s
);
std
::
vector
<
int64_t
>
perm
{
0
,
2
,
1
,
3
};
auto
tran_seq
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{
perm
},
seq
);
std
::
vector
<
int64_t
>
out_shape
{
0
,
0
,
-
1
};
p
.
add_instruction
(
migraphx
::
op
::
reshape
{
out_shape
},
tran_seq
);
return
p
;
}
};
struct
test_transpose
:
verify_program
<
test_transpose
>
struct
test_transpose
:
verify_program
<
test_transpose
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
...
...
test/tf/tf_test.cpp
View file @
833cc7a2
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
op
.
group
=
3
;
op
.
group
=
3
;
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l3
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
l3
);
p
.
add_instruction
(
op
,
l0
,
l4
);
auto
l5
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l4
);
p
.
add_instruction
(
op
,
l0
,
l5
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment