Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60d5fa1a
"docs/source/hpo_advanced.rst" did not exist on "31afa426fbb097d9dac6546df76ddcebfe45eb8d"
Unverified
Commit
60d5fa1a
authored
Nov 03, 2022
by
Charlie Lin
Committed by
GitHub
Nov 03, 2022
Browse files
Merge branch 'develop' into dyn_ref_multibroadcast
parents
41268947
1820198e
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
625 additions
and
64 deletions
+625
-64
src/targets/gpu/include/migraphx/gpu/convolution.hpp
src/targets/gpu/include/migraphx/gpu/convolution.hpp
+5
-8
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+28
-11
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+3
-3
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+21
-9
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+12
-32
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+4
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+10
-0
test/layout_nhwc.cpp
test/layout_nhwc.cpp
+127
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+86
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+42
-0
test/onnx/split_test_invalid_split.onnx
test/onnx/split_test_invalid_split.onnx
+25
-0
test/onnx/split_test_no_attribute.onnx
test/onnx/split_test_no_attribute.onnx
+26
-0
test/onnx/split_test_no_attribute_invalid_input_split.onnx
test/onnx/split_test_no_attribute_invalid_input_split.onnx
+26
-0
test/onnx/split_test_no_attribute_invalid_split.onnx
test/onnx/split_test_no_attribute_invalid_split.onnx
+26
-0
test/verify/test_concat_broadcast_add.cpp
test/verify/test_concat_broadcast_add.cpp
+49
-0
test/verify/test_slice_concat_add.cpp
test/verify/test_slice_concat_add.cpp
+47
-0
tools/convert_onnx_version.py
tools/convert_onnx_version.py
+88
-0
No files found.
src/targets/gpu/include/migraphx/gpu/convolution.hpp
View file @
60d5fa1a
...
@@ -83,9 +83,10 @@ struct miopen_convolution
...
@@ -83,9 +83,10 @@ struct miopen_convolution
inline
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
inline
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
check_shapes
{
inputs
,
op
}.
has
(
4
)
.
standard
()
;
check_shapes
{
inputs
,
op
}.
has
(
4
);
std
::
vector
<
shape
>
conv_inputs
(
inputs
.
begin
(),
inputs
.
begin
()
+
2
);
std
::
vector
<
shape
>
conv_inputs
(
inputs
.
begin
(),
inputs
.
begin
()
+
2
);
check_shapes
{
conv_inputs
,
op
}.
max_ndims
(
5
);
check_shapes
{
conv_inputs
,
*
this
}.
max_ndims
(
5
).
packed_layouts
(
{{
0
,
1
,
2
},
{
0
,
1
,
2
,
3
},
{
0
,
2
,
3
,
1
},
{
0
,
1
,
2
,
3
,
4
}});
return
migraphx
::
compute_shape
<
Op
>
(
op
,
conv_inputs
);
return
migraphx
::
compute_shape
<
Op
>
(
op
,
conv_inputs
);
}
}
...
@@ -144,12 +145,9 @@ struct miopen_convolution
...
@@ -144,12 +145,9 @@ struct miopen_convolution
#endif
#endif
}
}
inline
void
set_conv_descriptor
()
void
set_conv_descriptor
()
{
{
if
(
cd
==
nullptr
)
cd
=
(
op
.
name
()
==
"deconvolution"
)
?
make_deconv
(
op
)
:
make_conv
(
op
);
{
cd
=
(
op
.
name
()
==
"deconvolution"
)
?
make_deconv
(
op
)
:
make_conv
(
op
);
}
}
}
value
compile
(
migraphx
::
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
value
compile
(
migraphx
::
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
...
@@ -239,7 +237,6 @@ struct miopen_convolution
...
@@ -239,7 +237,6 @@ struct miopen_convolution
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen "
+
op
.
name
()
+
" : find convolution failed"
);
MIGRAPHX_THROW
(
"MIOpen "
+
op
.
name
()
+
" : find convolution failed"
);
algo
=
perf
.
fwd_algo
;
algo
=
perf
.
fwd_algo
;
size_t
solution_count
;
size_t
solution_count
;
status
=
miopenConvolutionForwardGetSolutionCount
(
ctx
.
get_stream
().
get_miopen
(),
status
=
miopenConvolutionForwardGetSolutionCount
(
ctx
.
get_stream
().
get_miopen
(),
...
...
src/targets/gpu/jit/concat.cpp
View file @
60d5fa1a
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
static
const
char
*
const
concat_kernel
=
R"__migraphx__(
static
const
char
*
const
concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
#include <args.hpp>
namespace migraphx {
namespace migraphx {
${preamble}
extern "C" {
extern "C" {
__global__ void ${kernel}(${params})
__global__ void ${kernel}(${params})
{
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y,
${concat_params},
auto... xs) {
concat<${axis}>(y, xs...);
concat<${axis}>(
${concat_args})(${post},
y, xs...);
});
});
}
}
...
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
// TODO: Use reduce_dims
auto
num_of_concat_inputs
=
v
.
get
(
"concat_inputs"
,
inputs
.
size
()
-
1
);
hip_compile_options
options
;
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
concat_kernel
,
auto
src
=
interpolate_string
(
{{
"kernel"
,
options
.
kernel_name
},
concat_kernel
,
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"kernel"
,
options
.
kernel_name
},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
{
"concat_params"
,
enum_params
(
num_of_concat_inputs
,
"auto concat_x"
)},
{
"concat_args"
,
enum_params
(
num_of_concat_inputs
,
"concat_x"
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
auto
v
=
op
.
to_value
();
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"concat_inputs"
]
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
}
};
};
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
60d5fa1a
...
@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
...
@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
,
"layout"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
{
...
@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
if
(
op
.
name
()
==
"contiguous"
)
if
(
contains
({
"layout"
,
"contiguous"
},
op
.
name
())
)
{
{
return
replace
(
compile_op
(
return
replace
(
compile_op
(
ctx
,
ctx
,
to_shapes
(
ins
->
inputs
()),
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
"[](auto x) { return x; }"
},
{
"kernel"
,
"contiguous
_kernel"
}}));
{{
"lambda"
,
"[](auto x) { return x; }"
},
{
"kernel"
,
op
.
name
()
+
"
_kernel"
}}));
}
}
else
else
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
60d5fa1a
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
return
Start
{}
*
output_shape
.
strides
[
Axis
];
return
Start
{}
*
output_shape
.
strides
[
Axis
];
});
});
constexpr
auto
s
=
make_shape
(
lens
,
strides
);
constexpr
auto
s
=
make_shape
(
lens
,
strides
);
return
make_tensor_view
(
&
out
[
offset
],
s
);
MIGRAPHX_ASSERT
(
offset
<
out
.
get_shape
().
element_space
());
MIGRAPHX_ASSERT
((
s
.
element_space
()
+
offset
)
<=
out
.
get_shape
().
element_space
());
return
make_tensor_view
(
out
.
data
()
+
offset
,
s
);
}
template
<
index_int
Axis
,
class
Input
,
class
Start
,
class
...
Ts
>
constexpr
auto
concat_slices
(
Input
input
,
Start
start
,
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
f
(
concat_slice
<
Axis
>
(
xs
,
input
,
start
)...);
};
}
}
template
<
index_int
Axis
,
class
Input
>
template
<
index_int
Axis
,
class
Input
>
...
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
...
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
return
_c
<
lens
[
Axis
]
>
;
return
_c
<
lens
[
Axis
]
>
;
}
}
template
<
index_int
Axis
,
class
Output
,
class
...
Inputs
>
template
<
index_int
Axis
,
class
...
Inputs
>
__device__
void
concat
(
Output
output
,
Inputs
...
inputs
)
__device__
auto
concat
(
Inputs
...
inputs
)
{
{
auto
idx
=
make_index
();
return
[
=
](
auto
f
,
auto
...
ts
)
{
fold
([
&
](
auto
start
,
auto
input
)
{
auto
idx
=
make_index
();
auto
y
=
concat_slice
<
Axis
>
(
output
,
input
,
start
);
fold
([
&
](
auto
start
,
auto
input
)
{
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
y
[
i
]
=
input
[
i
];
});
concat_slices
<
Axis
>
(
input
,
start
,
ts
...)([
&
](
auto
y
,
auto
...
xs
)
{
return
start
+
concat_ends
<
Axis
>
(
input
);
idx
.
global_stride
(
input
.
get_shape
().
elements
(),
})(
_c
<
0
>
,
inputs
...);
[
&
](
auto
i
)
{
y
[
i
]
=
f
(
input
[
i
],
xs
[
i
]...);
});
});
return
start
+
concat_ends
<
Axis
>
(
input
);
})(
_c
<
0
>
,
inputs
...);
};
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/lowering.cpp
View file @
60d5fa1a
...
@@ -29,19 +29,14 @@
...
@@ -29,19 +29,14 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
...
@@ -109,9 +104,9 @@ struct miopen_apply
...
@@ -109,9 +104,9 @@ struct miopen_apply
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"topk"
);
add_extend_op
(
"topk"
);
add_convolution_op
<
op
::
convolution
>
(
"convolution"
);
add_convolution_op
(
"convolution"
);
add_convolution_op
<
op
::
deconvolution
>
(
"deconvolution"
);
add_convolution_op
(
"deconvolution"
);
add_convolution_op
<
op
::
quant_convolution
>
(
"quant_convolution"
);
add_convolution_op
(
"quant_convolution"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
add_if_op
();
add_if_op
();
...
@@ -238,34 +233,19 @@ struct miopen_apply
...
@@ -238,34 +233,19 @@ struct miopen_apply
});
});
}
}
template
<
typename
Op
>
void
add_convolution_op
(
const
std
::
string
&
name
)
void
add_convolution_op
(
const
std
::
string
&
name
)
{
{
apply_map
.
emplace
(
name
,
[
=
](
instruction_ref
ins
)
{
apply_map
.
emplace
(
name
,
[
=
](
instruction_ref
ins
)
{
operation
conv
=
operation
conv
=
make_op
(
miopen_convolution
<
Op
>
{
any_cast
<
Op
>
(
ins
->
get_operator
()),
int8_x4_format
};
"gpu::"
+
name
,
migraphx
::
context
ctx
=
get_context
();
{{
"op"
,
ins
->
get_operator
().
to_value
()},
{
"int8_x4_format"
,
int8_x4_format
}});
size_t
ws_bytes
=
0
;
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
compile_conv_with_format
=
[
&
](
bool
format
)
{
conv
=
miopen_convolution
<
Op
>
{
any_cast
<
Op
>
(
ins
->
get_operator
()),
format
};
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
ws_bytes
=
ws
.
get
(
"workspace"
,
0
);
};
try
{
// for the regular convolution and deconvolution, this try would always succeed
compile_conv_with_format
(
int8_x4_format
);
}
catch
(
migraphx
::
exception
&
)
{
// In case no solver supports the default format, retry using the other format.
compile_conv_with_format
(
not
int8_x4_format
);
}
auto
args
=
ins
->
inputs
();
return
mod
->
replace_instruction
(
ins
,
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
make_op
(
"gpu::miopen_op"
,
{{
"op"
,
to_value
(
conv
)}}),
auto
workspace
=
insert_allocation
(
ins
,
shape
{
shape
::
int8_type
,
{
ws_bytes
}});
ins
->
inputs
().
at
(
0
),
return
mod
->
replace_instruction
(
ins
,
conv
,
args
[
0
],
args
[
1
],
workspace
,
output
);
ins
->
inputs
().
at
(
1
),
output
);
});
});
}
}
...
...
src/targets/gpu/mlir.cpp
View file @
60d5fa1a
...
@@ -101,7 +101,10 @@ struct mlir_handle
...
@@ -101,7 +101,10 @@ struct mlir_handle
mlir_handle
(
T
p
)
:
handle
(
ptr
{
p
})
{}
mlir_handle
(
T
p
)
:
handle
(
ptr
{
p
})
{}
T
get
()
const
{
return
handle
.
get
().
get
();
}
T
get
()
const
{
return
handle
.
get
().
get
();
// NOLINT(readability-redundant-smartptr-get)
}
T
release
()
{
return
handle
.
release
().
get
();
}
T
release
()
{
return
handle
.
release
().
get
();
}
...
...
src/targets/gpu/target.cpp
View file @
60d5fa1a
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/preallocate_param.hpp>
...
@@ -50,6 +51,7 @@
...
@@ -50,6 +51,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
...
@@ -70,6 +72,7 @@ namespace gpu {
...
@@ -70,6 +72,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
struct
id_pass
struct
id_pass
{
{
...
@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
simplify_reshapes
{},
simplify_reshapes
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
dead_code_elimination
{},
simplify_reshapes
{},
simplify_algebra
{},
simplify_algebra
{},
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_concat
{
concat_gpu_optimization
{}},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
dead_code_elimination
{},
compile_miopen
{
&
gctx
},
dead_code_elimination
{},
pack_int8_args
{},
pack_int8_args
{},
dead_code_elimination
{},
dead_code_elimination
{},
adjust_allocation
{
gpu_allocation_model
{}},
dead_code_elimination
{},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
dead_code_elimination
{},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
...
...
test/layout_nhwc.cpp
0 → 100644
View file @
60d5fa1a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
layout_nhwc
{},
migraphx
::
dead_code_elimination
{}});
}
migraphx
::
operation
layout
(
std
::
vector
<
int64_t
>
permutation
=
{
0
,
1
,
2
,
3
})
{
return
migraphx
::
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}});
}
migraphx
::
instruction_ref
add_layout_nhwc
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
ins
)
{
return
m
.
add_instruction
(
layout
({
0
,
2
,
3
,
1
}),
ins
);
}
TEST_CASE
(
conv_relu
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
16
,
16
}});
auto
w
=
m1
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
16
,
8
,
3
,
3
}}));
auto
conv
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}},
{
"dilation"
,
{
1
,
1
}}}),
x
,
w
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"relu"
),
conv
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
add_layout_nhwc
(
m2
,
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
16
,
16
}}));
auto
w
=
add_layout_nhwc
(
m2
,
m2
.
add_literal
(
migraphx
::
generate_literal
(
{
migraphx
::
shape
::
float_type
,
{
16
,
8
,
3
,
3
}})));
auto
conv
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}},
{
"dilation"
,
{
1
,
1
}}}),
x
,
w
);
auto
conv_layout
=
m2
.
add_instruction
(
layout
(),
conv
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"relu"
),
conv_layout
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
conv_add
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
16
,
16
}});
auto
w
=
m1
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
16
,
8
,
3
,
3
}}));
auto
y
=
m1
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
16
}}));
auto
conv
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}},
{
"dilation"
,
{
1
,
1
}}}),
x
,
w
);
auto
b
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
conv
->
get_shape
().
lens
()}}),
y
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv
,
b
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
add_layout_nhwc
(
m2
,
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
16
,
16
}}));
auto
w
=
add_layout_nhwc
(
m2
,
m2
.
add_literal
(
migraphx
::
generate_literal
(
{
migraphx
::
shape
::
float_type
,
{
16
,
8
,
3
,
3
}})));
auto
y
=
m2
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
float_type
,
{
16
}}));
auto
conv
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
,
2
}},
{
"dilation"
,
{
1
,
1
}}}),
x
,
w
);
auto
conv_layout
=
m2
.
add_instruction
(
layout
(),
conv
);
auto
b
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
conv
->
get_shape
().
lens
()}}),
y
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv_layout
,
b
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/onnx/gen_onnx.py
View file @
60d5fa1a
...
@@ -5755,6 +5755,92 @@ def split_test_default():
...
@@ -5755,6 +5755,92 @@ def split_test_default():
return
([
node
],
[
x
],
[
y1
,
y2
])
return
([
node
],
[
x
],
[
y1
,
y2
])
@
onnx_test
def
split_test_no_attribute
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
300
,
15
])
y1
=
helper
.
make_tensor_value_info
(
'y1'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y2
=
helper
.
make_tensor_value_info
(
'y2'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y3
=
helper
.
make_tensor_value_info
(
'y3'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y4
=
helper
.
make_tensor_value_info
(
'y4'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
split
=
np
.
ones
(
4
)
*
75
split_tensor
=
helper
.
make_tensor
(
name
=
"split"
,
data_type
=
TensorProto
.
INT64
,
dims
=
split
.
shape
,
vals
=
split
.
astype
(
np
.
int64
))
const_node
=
helper
.
make_node
(
"Constant"
,
inputs
=
[],
outputs
=
[
'split'
],
value
=
split_tensor
)
node
=
onnx
.
helper
.
make_node
(
'Split'
,
inputs
=
[
'x'
,
'split'
],
outputs
=
[
'y1'
,
'y2'
,
'y3'
,
'y4'
],
)
return
([
const_node
,
node
],
[
x
],
[
y1
,
y2
,
y3
,
y4
])
@
onnx_test
def
split_test_no_attribute_invalid_split
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
300
,
15
])
y1
=
helper
.
make_tensor_value_info
(
'y1'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y2
=
helper
.
make_tensor_value_info
(
'y2'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y3
=
helper
.
make_tensor_value_info
(
'y3'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
y4
=
helper
.
make_tensor_value_info
(
'y4'
,
TensorProto
.
FLOAT
,
[
75
,
15
])
split
=
np
.
ones
(
4
)
split_tensor
=
helper
.
make_tensor
(
name
=
"split"
,
data_type
=
TensorProto
.
INT64
,
dims
=
split
.
shape
,
vals
=
split
.
astype
(
np
.
int64
))
const_node
=
helper
.
make_node
(
"Constant"
,
inputs
=
[],
outputs
=
[
'split'
],
value
=
split_tensor
)
node
=
onnx
.
helper
.
make_node
(
'Split'
,
inputs
=
[
'x'
,
'split'
],
outputs
=
[
'y1'
,
'y2'
,
'y3'
,
'y4'
],
)
return
([
const_node
,
node
],
[
x
],
[
y1
,
y2
,
y3
,
y4
])
@
onnx_test
def
split_test_invalid_split
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
,
15
])
y1
=
helper
.
make_tensor_value_info
(
'y1'
,
TensorProto
.
FLOAT
,
[
10
,
7
])
y2
=
helper
.
make_tensor_value_info
(
'y2'
,
TensorProto
.
FLOAT
,
[
10
,
4
])
y3
=
helper
.
make_tensor_value_info
(
'y3'
,
TensorProto
.
FLOAT
,
[
10
,
4
])
node
=
onnx
.
helper
.
make_node
(
'Split'
,
inputs
=
[
'x'
],
outputs
=
[
'y1'
,
'y2'
,
'y3'
],
axis
=
1
,
split
=
[
1
,
1
,
1
])
return
([
node
],
[
x
],
[
y1
,
y2
,
y3
])
@
onnx_test
def
split_test_no_attribute_invalid_input_split
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
,
15
])
y1
=
helper
.
make_tensor_value_info
(
'y1'
,
TensorProto
.
FLOAT
,
[
10
,
7
])
y2
=
helper
.
make_tensor_value_info
(
'y2'
,
TensorProto
.
FLOAT
,
[
10
,
4
])
y3
=
helper
.
make_tensor_value_info
(
'y3'
,
TensorProto
.
FLOAT
,
[
10
,
4
])
node
=
onnx
.
helper
.
make_node
(
'Split'
,
inputs
=
[
'x'
],
outputs
=
[
'y1'
,
'y2'
,
'y3'
],
axis
=
1
,
split
=
[])
return
([
node
],
[
x
],
[
y1
,
y2
,
y3
])
@
onnx_test
@
onnx_test
def
sqrt_test
():
def
sqrt_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
,
15
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
,
15
])
...
...
test/onnx/onnx_test.cpp
View file @
60d5fa1a
...
@@ -5607,6 +5607,31 @@ TEST_CASE(split_test)
...
@@ -5607,6 +5607,31 @@ TEST_CASE(split_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
split_test_no_attribute
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
si
{
migraphx
::
shape
::
int64_type
,
{
4
},
{
1
}};
std
::
vector
<
int
>
ind
=
{
75
,
75
,
75
,
75
};
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
300
,
15
}});
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
r1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
75
}}}),
input
);
auto
r2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
75
}},
{
"ends"
,
{
150
}}}),
input
);
auto
r3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
150
}},
{
"ends"
,
{
225
}}}),
input
);
auto
r4
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
225
}},
{
"ends"
,
{
300
}}}),
input
);
mm
->
add_return
({
r1
,
r2
,
r3
,
r4
});
auto
prog
=
migraphx
::
parse_onnx
(
"split_test_no_attribute.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
split_test_default
)
TEST_CASE
(
split_test_default
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -5622,6 +5647,23 @@ TEST_CASE(split_test_default)
...
@@ -5622,6 +5647,23 @@ TEST_CASE(split_test_default)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
split_test_no_attribute_invalid_split
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"split_test_no_attribute_invalid_split.onnx"
);
}));
}
TEST_CASE
(
split_test_invalid_split
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"split_test_invalid_split.onnx"
);
}));
}
TEST_CASE
(
split_test_no_attribute_invalid_input_split
)
{
EXPECT
(
test
::
throws
(
[
&
]
{
migraphx
::
parse_onnx
(
"split_test_no_attribute_invalid_input_split.onnx"
);
}));
}
TEST_CASE
(
sqrt_test
)
TEST_CASE
(
sqrt_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/onnx/split_test_invalid_split.onnx
0 → 100644
View file @
60d5fa1a
split_test_invalid_split:
5
xy1y2y3"Split*
axis*
split@@@split_test_invalid_splitZ
x
b
y1
b
y2
b
y3
B
\ No newline at end of file
test/onnx/split_test_no_attribute.onnx
0 → 100644
View file @
60d5fa1a
split_test_no_attribute:
0split"Constant*
value*:KKKKBsplit
!
x
splity1y2y3y4"Splitsplit_test_no_attributeZ
x
b
y1
K
b
y2
K
b
y3
K
b
y4
K
B
\ No newline at end of file
test/onnx/split_test_no_attribute_invalid_input_split.onnx
0 → 100644
View file @
60d5fa1a
+split_test_no_attribute_invalid_input_split:
/
xy1y2y3"Split*
axis*
split+split_test_no_attribute_invalid_input_splitZ
x
b
y1
b
y2
b
y3
B
\ No newline at end of file
test/onnx/split_test_no_attribute_invalid_split.onnx
0 → 100644
View file @
60d5fa1a
%split_test_no_attribute_invalid_split:
0split"Constant*
value*:Bsplit
!
x
splity1y2y3y4"Split%split_test_no_attribute_invalid_splitZ
x
b
y1
K
b
y2
K
b
y3
K
b
y4
K
B
\ No newline at end of file
test/verify/test_concat_broadcast_add.cpp
0 → 100644
View file @
60d5fa1a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_concat_broadcast_add
:
verify_program
<
test_concat_broadcast_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
4
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
4
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
6
,
1
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s0
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s0
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s0
);
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
x
,
y
,
z
);
auto
b
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
s2
,
15
));
auto
bb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s1
.
lens
()}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
concat
,
bb
);
return
p
;
}
};
test/verify/test_slice_concat_add.cpp
0 → 100644
View file @
60d5fa1a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_slice_concat_add
:
verify_program
<
test_slice_concat_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
1
,
24
,
2
,
2
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
2
,
2
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s0
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s0
);
auto
slice
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
x
);
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
slice
,
y
,
y
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
concat
,
z
);
return
p
;
}
};
tools/convert_onnx_version.py
0 → 100644
View file @
60d5fa1a
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import
argparse
import
onnx
from
onnx
import
version_converter
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'MIGraphX Onnx Model Convertion. Use to convert the opset of the input model to MIGraphX
\'
s'
)
req_args
=
parser
.
add_argument_group
(
title
=
'required arguments'
)
req_args
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'path to onnx file'
)
req_args
.
add_argument
(
'--output'
,
type
=
str
,
required
=
True
,
help
=
'path to output onnx file'
)
req_args
.
add_argument
(
'--opset'
,
type
=
int
,
required
=
True
,
help
=
'The output opset'
)
req_args
.
add_argument
(
'--infer_shapes'
,
action
=
'store_true'
,
help
=
'Infer shapes for output model'
)
parser
.
add_argument
(
'--verbose'
,
action
=
'store_true'
,
help
=
'show verbose information (for debugging)'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
model_path
=
args
.
model
out_model_path
=
args
.
output
target_opset
=
args
.
opset
verbose
=
args
.
verbose
infer_shapes
=
args
.
infer_shapes
original_model
=
onnx
.
load
(
model_path
)
if
verbose
:
print
(
f
"The model before conversion:
\n
{
original_model
}
"
)
# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model
converted_model
=
version_converter
.
convert_version
(
original_model
,
target_opset
)
if
infer_shapes
:
converted_model
=
onnx
.
shape_inference
.
infer_shapes
(
converted_model
)
if
verbose
:
print
(
f
"The model after conversion:
\n
{
converted_model
}
"
)
# Save the ONNX model
onnx
.
save
(
converted_model
,
out_model_path
)
if
__name__
==
'__main__'
:
main
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment