Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7cc3243c
Commit
7cc3243c
authored
Sep 26, 2019
by
Paul Fultz II
Committed by
mvermeulen
Sep 26, 2019
Browse files
Fix exception thrown when compiling inceptionv4 (#367)
* Fix compiler crash in TF inceptionv4 * Formatting * Remove else
parent
3962c2ad
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
20 deletions
+69
-20
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+3
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+17
-17
test/gpu/ops_test.cpp
test/gpu/ops_test.cpp
+27
-2
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+22
-0
No files found.
src/rewrite_pooling.cpp
View file @
7cc3243c
...
@@ -20,6 +20,8 @@ void rewrite_pooling::apply(program& prog) const
...
@@ -20,6 +20,8 @@ void rewrite_pooling::apply(program& prog) const
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
continue
;
continue
;
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
not
s
.
standard
())
continue
;
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
op
.
mode
!=
"average"
)
if
(
op
.
mode
!=
"average"
)
continue
;
continue
;
...
...
src/simplify_reshapes.cpp
View file @
7cc3243c
...
@@ -177,8 +177,7 @@ struct find_concat_transpose
...
@@ -177,8 +177,7 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
same_input_shapes
(),
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
...
@@ -194,8 +193,6 @@ struct find_concat_transpose
...
@@ -194,8 +193,6 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
if
(
i
->
name
()
==
"transpose"
and
i
->
inputs
().
front
()
->
get_shape
().
standard
())
return
i
->
inputs
().
front
();
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
...
@@ -207,6 +204,8 @@ struct find_concat_transpose
...
@@ -207,6 +204,8 @@ struct find_concat_transpose
void
simplify_reshapes
::
apply
(
program
&
p
)
const
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
auto
end
=
std
::
prev
(
p
.
end
());
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
...
@@ -222,6 +221,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -222,6 +221,7 @@ void simplify_reshapes::apply(program& p) const
find_transpose
{},
find_transpose
{},
find_concat_transpose
{});
find_concat_transpose
{});
}
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
test/gpu/ops_test.cpp
View file @
7cc3243c
#include <migraphx/env.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
...
@@ -24,6 +25,8 @@
...
@@ -24,6 +25,8 @@
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
#endif
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_GPU_COMPILE
)
// An improved async, that doesn't block
// An improved async, that doesn't block
template
<
class
Function
>
template
<
class
Function
>
std
::
future
<
typename
std
::
result_of
<
Function
()
>::
type
>
detach_async
(
Function
&&
f
,
std
::
future
<
typename
std
::
result_of
<
Function
()
>::
type
>
detach_async
(
Function
&&
f
,
...
@@ -82,7 +85,7 @@ auto get_hash(const T& x)
...
@@ -82,7 +85,7 @@ auto get_hash(const T& x)
return
std
::
hash
<
T
>
{}(
x
);
return
std
::
hash
<
T
>
{}(
x
);
}
}
void
compile_check
(
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
)
void
compile_check
(
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
,
bool
show_trace
=
false
)
{
{
auto
name
=
t
.
name
();
auto
name
=
t
.
name
();
auto
s
=
p
.
get_shape
();
auto
s
=
p
.
get_shape
();
...
@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t)
...
@@ -93,6 +96,10 @@ void compile_check(migraphx::program& p, const migraphx::target& t)
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its shape"
);
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its shape"
);
}
}
if
(
show_trace
)
{
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
}
}
}
template
<
class
V
>
template
<
class
V
>
...
@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p)
...
@@ -116,7 +123,7 @@ migraphx::argument run_gpu(migraphx::program& p)
V
v
;
V
v
;
p
=
v
.
create_program
();
p
=
v
.
create_program
();
auto_print
pp
{
p
,
1
};
auto_print
pp
{
p
,
1
};
compile_check
(
p
,
migraphx
::
gpu
::
target
{});
compile_check
(
p
,
migraphx
::
gpu
::
target
{}
,
migraphx
::
enabled
(
MIGRAPHX_TRACE_GPU_COMPILE
{})
);
migraphx
::
program
::
parameter_map
m
;
migraphx
::
program
::
parameter_map
m
;
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
{
{
...
@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
...
@@ -985,6 +992,24 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
}
}
};
};
struct
test_concat_pooling
:
verify_program
<
test_concat_pooling
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
input
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
256
,
8
,
8
}});
auto
transpose
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
input
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
transpose
);
auto
concat_t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
concat
);
auto
pooling
=
p
.
add_instruction
(
migraphx
::
op
::
pooling
{
"average"
,
{
0
,
0
},
{
1
,
1
},
{
8
,
8
}},
concat_t
);
p
.
add_instruction
(
migraphx
::
op
::
relu
{},
pooling
);
return
p
;
}
};
struct
test_global_avg_pooling
:
verify_program
<
test_global_avg_pooling
>
struct
test_global_avg_pooling
:
verify_program
<
test_global_avg_pooling
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
...
...
test/simplify_reshapes_test.cpp
View file @
7cc3243c
...
@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2)
...
@@ -306,4 +306,26 @@ TEST_CASE(concat_transpose2)
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
}
TEST_CASE
(
concat_transpose3
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}});
auto
y
=
p
.
add_parameter
(
"y"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
3
,
4
}});
auto
xt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
x
);
auto
yt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
y
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
xt
,
yt
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
concat
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
p
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment