Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2c60e428
Unverified
Commit
2c60e428
authored
May 21, 2019
by
mvermeulen
Committed by
GitHub
May 21, 2019
Browse files
Merge pull request #263 from ROCmSoftwarePlatform/reshape
Re-enable simplify_reshapes
parents
cc8605e4
973b496b
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
23 additions
and
36 deletions
+23
-36
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+0
-7
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+7
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+1
-1
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+1
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+6
-8
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-1
src/tf/tf.cpp
src/tf/tf.cpp
+3
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+0
-16
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+3
-2
No files found.
src/eliminate_contiguous.cpp
View file @
2c60e428
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if
(
ins
->
name
()
==
"reshape"
)
{
continue
;
}
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
...
...
src/include/migraphx/check_shapes.hpp
View file @
2c60e428
...
@@ -103,6 +103,13 @@ struct check_shapes
...
@@ -103,6 +103,13 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
...
...
src/include/migraphx/op/reshape.hpp
View file @
2c60e428
...
@@ -29,7 +29,7 @@ struct reshape
...
@@ -29,7 +29,7 @@ struct reshape
std
::
string
name
()
const
{
return
"reshape"
;
}
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
...
...
src/include/migraphx/op/squeeze.hpp
View file @
2c60e428
...
@@ -29,6 +29,7 @@ struct squeeze
...
@@ -29,6 +29,7 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
2c60e428
...
@@ -29,6 +29,7 @@ struct unsqueeze
...
@@ -29,6 +29,7 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard_or_scalar
();
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
...
...
src/simplify_reshapes.cpp
View file @
2c60e428
...
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
...
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"reshape"
,
"contiguous"
"contiguous"
,
"squeeze"
,
"unsqueeze"
};
};
// clang-format on
// clang-format on
return
contains
(
names
,
ins
->
name
());
return
contains
(
names
,
ins
->
name
());
...
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
...
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto
end
=
std
::
prev
(
p
.
end
());
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
continue
;
if
(
is_reshaper
(
ins
))
if
(
is_reshaper
(
ins
))
...
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
...
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
}
}
// Replace all reshapes with as_shape
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"reshape"
)
continue
;
p
.
replace_instruction
(
ins
,
op
::
as_shape
{
ins
->
get_shape
()},
ins
->
inputs
());
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
2c60e428
...
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
propagate_constant
{},
propagate_constant
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
//
simplify_reshapes{},
simplify_reshapes
{},
dead_code_elimination
{},
dead_code_elimination
{},
lowering
{
ctx
},
lowering
{
ctx
},
eliminate_concat
{
concat_gpu_optimization
{}},
eliminate_concat
{
concat_gpu_optimization
{}},
...
...
src/tf/tf.cpp
View file @
2c60e428
...
@@ -393,7 +393,9 @@ struct tf_parser
...
@@ -393,7 +393,9 @@ struct tf_parser
int64_t
out_channels
=
num_channels
*
multiplier
;
int64_t
out_channels
=
num_channels
*
multiplier
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
new_weights_shape
[
1
]
=
1
;
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
weights
);
// Make sure weights are contiguous before doing reshape
auto
cweights
=
prog
.
add_instruction
(
op
::
contiguous
{},
weights
);
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
cweights
);
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
}
}
...
...
test/gpu/miopen.cpp
View file @
2c60e428
...
@@ -1251,22 +1251,6 @@ struct test_contiguous : verify_program<test_contiguous>
...
@@ -1251,22 +1251,6 @@ struct test_contiguous : verify_program<test_contiguous>
}
}
};
};
struct
test_eliminate_contiguous
:
verify_program
<
test_eliminate_contiguous
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
auto
seq
=
p
.
add_parameter
(
"seq"
,
s
);
std
::
vector
<
int64_t
>
perm
{
0
,
2
,
1
,
3
};
auto
tran_seq
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{
perm
},
seq
);
std
::
vector
<
int64_t
>
out_shape
{
0
,
0
,
-
1
};
p
.
add_instruction
(
migraphx
::
op
::
reshape
{
out_shape
},
tran_seq
);
return
p
;
}
};
struct
test_transpose
:
verify_program
<
test_transpose
>
struct
test_transpose
:
verify_program
<
test_transpose
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
...
...
test/tf/tf_test.cpp
View file @
2c60e428
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
op
.
group
=
3
;
op
.
group
=
3
;
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l3
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
l3
);
p
.
add_instruction
(
op
,
l0
,
l4
);
auto
l5
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l4
);
p
.
add_instruction
(
op
,
l0
,
l5
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment