Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2c60e428
Unverified
Commit
2c60e428
authored
May 21, 2019
by
mvermeulen
Committed by
GitHub
May 21, 2019
Browse files
Merge pull request #263 from ROCmSoftwarePlatform/reshape
Re-enable simplify_reshapes
parents
cc8605e4
973b496b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
23 additions
and
36 deletions
+23
-36
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+0
-7
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+7
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+1
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+6
-8
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-1
src/tf/tf.cpp
src/tf/tf.cpp
+3
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+0
-16
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+3
-2
No files found.
src/eliminate_contiguous.cpp
View file @
2c60e428
...
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if
(
ins
->
name
()
==
"reshape"
)
{
continue
;
}
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
for
(
auto
arg
:
ins
->
inputs
())
...
...
src/include/migraphx/check_shapes.hpp
View file @
2c60e428
...
...
@@ -103,6 +103,13 @@ struct check_shapes
return
*
this
;
}
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
...
...
src/include/migraphx/op/reshape.hpp
View file @
2c60e428
...
...
@@ -29,7 +29,7 @@ struct reshape
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
...
...
src/include/migraphx/op/squeeze.hpp
View file @
2c60e428
...
...
@@ -29,6 +29,7 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
2c60e428
...
...
@@ -29,6 +29,7 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard_or_scalar
();
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/simplify_reshapes.cpp
View file @
2c60e428
...
...
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"contiguous"
"contiguous"
,
"squeeze"
,
"unsqueeze"
};
// clang-format on
return
contains
(
names
,
ins
->
name
());
...
...
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
if
(
is_reshaper
(
ins
))
...
...
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
// Replace all reshapes with as_shape
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"reshape"
)
continue
;
p
.
replace_instruction
(
ins
,
op
::
as_shape
{
ins
->
get_shape
()},
ins
->
inputs
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
2c60e428
...
...
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
propagate_constant
{},
dead_code_elimination
{},
auto_contiguous
{},
//
simplify_reshapes{},
simplify_reshapes
{},
dead_code_elimination
{},
lowering
{
ctx
},
eliminate_concat
{
concat_gpu_optimization
{}},
...
...
src/tf/tf.cpp
View file @
2c60e428
...
...
@@ -393,7 +393,9 @@ struct tf_parser
int64_t
out_channels
=
num_channels
*
multiplier
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
weights
);
// Make sure weights are contiguous before doing reshape
auto
cweights
=
prog
.
add_instruction
(
op
::
contiguous
{},
weights
);
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
cweights
);
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
}
...
...
test/gpu/miopen.cpp
View file @
2c60e428
...
...
@@ -1251,22 +1251,6 @@ struct test_contiguous : verify_program<test_contiguous>
}
};
struct
test_eliminate_contiguous
:
verify_program
<
test_eliminate_contiguous
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
auto
seq
=
p
.
add_parameter
(
"seq"
,
s
);
std
::
vector
<
int64_t
>
perm
{
0
,
2
,
1
,
3
};
auto
tran_seq
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{
perm
},
seq
);
std
::
vector
<
int64_t
>
out_shape
{
0
,
0
,
-
1
};
p
.
add_instruction
(
migraphx
::
op
::
reshape
{
out_shape
},
tran_seq
);
return
p
;
}
};
struct
test_transpose
:
verify_program
<
test_transpose
>
{
migraphx
::
program
create_program
()
const
...
...
test/tf/tf_test.cpp
View file @
2c60e428
...
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
op
.
group
=
3
;
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l3
);
p
.
add_instruction
(
op
,
l0
,
l4
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
l3
);
auto
l5
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l4
);
p
.
add_instruction
(
op
,
l0
,
l5
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment