Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7cc3243c
Commit
7cc3243c
authored
Sep 26, 2019
by
Paul Fultz II
Committed by
mvermeulen
Sep 26, 2019
Browse files
Fix exception thrown when compiling inceptionv4 (#367)
* Fix compiler crash in TF inceptionv4 * Formatting * Remove else
parent
3962c2ad
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
20 deletions
+69
-20
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+3
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+17
-17
test/gpu/ops_test.cpp
test/gpu/ops_test.cpp
+27
-2
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+22
-0
No files found.
src/rewrite_pooling.cpp
View file @
7cc3243c
...
@@ -19,7 +19,9 @@ void rewrite_pooling::apply(program& prog) const
...
@@ -19,7 +19,9 @@ void rewrite_pooling::apply(program& prog) const
continue
;
continue
;
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
continue
;
continue
;
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
not
s
.
standard
())
continue
;
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
!=
"average"
)
if
(
op
.
mode
!=
"average"
)
continue
;
continue
;
...
...
src/simplify_reshapes.cpp
View file @
7cc3243c
...
@@ -177,8 +177,7 @@ struct find_concat_transpose
...
@@ -177,8 +177,7 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
same_input_shapes
(),
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
...
@@ -194,8 +193,6 @@ struct find_concat_transpose
...
@@ -194,8 +193,6 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
if
(
i
->
name
()
==
"transpose"
and
i
->
inputs
().
front
()
->
get_shape
().
standard
())
return
i
->
inputs
().
front
();
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
...
@@ -207,20 +204,23 @@ struct find_concat_transpose
...
@@ -207,20 +204,23 @@ struct find_concat_transpose
void
simplify_reshapes
::
apply
(
program
&
p
)
const
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
auto
end
=
std
::
prev
(
p
.
end
());
continue
;
for
(
auto
ins
:
iterator_for
(
p
))
// Skip possible dead instructions
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
continue
;
match
::
find_matches
(
p
,
// Skip possible dead instructions
ins
,
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
find_nop_reshapes
{},
continue
;
find_reshaper
{},
match
::
find_matches
(
p
,
find_transpose
{},
ins
,
find_concat_transpose
{});
find_nop_reshapes
{},
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
}
}
}
}
}
...
...
test/gpu/ops_test.cpp
View file @
7cc3243c
#include <migraphx/env.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
...
@@ -24,6 +25,8 @@
...
@@ -24,6 +25,8 @@
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
#endif
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_GPU_COMPILE)
// An improved async, that doesn't block
// An improved async, that doesn't block
template <class Function>
template <class Function>
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f,
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f,
...
@@ -82,7 +85,7 @@ auto get_hash(const T& x)
...
@@ -82,7 +85,7 @@ auto get_hash(const T& x)
return std::hash<T>{}(x);
return std::hash<T>{}(x);
}
}
void
compile_check
(
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
)
void compile_check(migraphx::program& p, const migraphx::target& t
, bool show_trace = false
)
{
{
auto name = t.name();
auto name = t.name();
auto s = p.get_shape();
auto s = p.get_shape();
...
@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t)
...
@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t)
std::cout << ss.str() << std::endl;
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape");
throw std::runtime_error("Compiling program with " + name + " alters its shape");
}
}
if(show_trace)
{
std::cout << ss.str() << std::endl;
}
}
}
template <class V>
template <class V>
...
@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p)
...
@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p)
V v;
V v;
p = v.create_program();
p = v.create_program();
auto_print pp{p, 1};
auto_print pp{p, 1};
compile_check
(
p
,
migraphx
::
gpu
::
target
{});
compile_check(p, migraphx::gpu::target{}
, migraphx::enabled(MIGRAPHX_TRACE_GPU_COMPILE{})
);
migraphx::program::parameter_map m;
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
for(auto&& x : p.get_parameter_shapes())
{
{
...
@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
...
@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
}
}
};
};
struct test_concat_pooling : verify_program<test_concat_pooling>
{
migraphx::program create_program() const
{
migraphx::program p;
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}});
auto transpose = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);
auto concat = p.add_instruction(migraphx::op::concat{3}, transpose);
auto concat_t = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat);
auto pooling =
p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t);
p.add_instruction(migraphx::op::relu{}, pooling);
return p;
}
};
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{
{
migraphx::program create_program() const
migraphx::program create_program() const
...
...
test/simplify_reshapes_test.cpp
View file @
7cc3243c
...
@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2)
...
@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2)
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
}
TEST_CASE
(
concat_transpose3
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}});
auto
y
=
p
.
add_parameter
(
"y"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
3
,
4
}});
auto
xt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
x
);
auto
yt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
y
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
xt
,
yt
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
concat
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
p
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment