MIGraphX, commit b07243b4 (unverified)
Authored Nov 04, 2022 by Umang Yadav; committed by GitHub on Nov 04, 2022

Merge branch 'develop' into tuning_warning

Parents: e3c9dcc4, 1820198e
Showing 18 changed files with 676 additions and 64 deletions (+676, -64).
src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp (+51 -0)
src/targets/gpu/include/migraphx/gpu/convolution.hpp (+5 -8)
src/targets/gpu/jit/concat.cpp (+28 -11)
src/targets/gpu/jit/pointwise.cpp (+3 -3)
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp (+21 -9)
src/targets/gpu/lowering.cpp (+12 -32)
src/targets/gpu/mlir.cpp (+4 -1)
src/targets/gpu/target.cpp (+10 -0)
test/layout_nhwc.cpp (+127 -0)
test/onnx/gen_onnx.py (+86 -0)
test/onnx/onnx_test.cpp (+42 -0)
test/onnx/split_test_invalid_split.onnx (+25 -0)
test/onnx/split_test_no_attribute.onnx (+26 -0)
test/onnx/split_test_no_attribute_invalid_input_split.onnx (+26 -0)
test/onnx/split_test_no_attribute_invalid_split.onnx (+26 -0)
test/verify/test_concat_broadcast_add.cpp (+49 -0)
test/verify/test_slice_concat_add.cpp (+47 -0)
tools/convert_onnx_version.py (+88 -0)
src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp
new file (mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;
struct context;
struct operation;

namespace gpu {

struct compile_miopen
{
    context* ctx = nullptr;

    std::string name() const { return "gpu::compile_miopen"; }

    void apply(module& m) const;

    std::size_t compile(operation& op, instruction_ref ins, bool format) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
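The new compile_miopen pass moves MIOpen compilation out of lowering: as the lowering.cpp and target.cpp hunks further down show, lowering now emits a gpu::miopen_op placeholder and compile_miopen{&gctx} runs later in the pass list. Its compile(operation&, instruction_ref, bool format) member takes the int8_x4 format flag and returns a std::size_t, presumably the workspace size that the old lowering code queried. A minimal sketch of the try-one-format-then-fall-back pattern that a bool-format compile interface makes possible; the names here are hypothetical and this is not the pass implementation itself:

// Sketch only: the fallback idea behind compiling with a preferred format first.
// compile_with_format stands in for whatever does the real MIOpen compilation.
#include <cstddef>
#include <exception>
#include <functional>

std::size_t compile_with_fallback(const std::function<std::size_t(bool)>& compile_with_format,
                                  bool preferred_format)
{
    try
    {
        // For regular convolution and deconvolution the first attempt normally succeeds.
        return compile_with_format(preferred_format);
    }
    catch(const std::exception&)
    {
        // If no solver supports the preferred format, retry with the other one.
        return compile_with_format(not preferred_format);
    }
}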
src/targets/gpu/include/migraphx/gpu/convolution.hpp
...
...
@@ -85,9 +85,10 @@ struct miopen_convolution
 inline shape compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, op}.has(4).standard();
+    check_shapes{inputs, op}.has(4);
     std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
-    check_shapes{conv_inputs, op}.max_ndims(5);
+    check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
+        {{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
     return migraphx::compute_shape<Op>(op, conv_inputs);
 }
...
...
@@ -146,12 +147,9 @@ struct miopen_convolution
 #endif
     }

-    inline void set_conv_descriptor()
+    void set_conv_descriptor()
     {
-        if(cd == nullptr)
-        {
-            cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
-        }
+        cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
     }

     value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
...
@@ -241,7 +239,6 @@ struct miopen_convolution
         if(status != miopenStatusSuccess)
             MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
         algo = perf.fwd_algo;
-        size_t solution_count;
         status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
...
...
src/targets/gpu/jit/concat.cpp
...
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
 static const char* const concat_kernel = R"__migraphx__(
 #include <migraphx/kernels/concat.hpp>
 #include <migraphx/kernels/vectorize.hpp>
+#include <migraphx/kernels/ops.hpp>
 #include <args.hpp>

 namespace migraphx {

+${preamble}

 extern "C" {
 __global__ void ${kernel}(${params})
 {
-    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
-        concat<${axis}>(y, xs...);
+    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y,
+                                                                               ${concat_params},
+                                                                               auto... xs) {
+        concat<${axis}>(${concat_args})(${post}, y, xs...);
     });
 }
...
...
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         // TODO: Use reduce_dims
+        auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
         hip_compile_options options;
         options.inputs      = inputs;
         options.output      = inputs.back();
         options.params      = "-Wno-float-equal";
-        options.kernel_name = v.get("kernel", "concat_kernel");
         auto axis           = find_fast_axis(options.inputs);
         auto vec            = vectorize::elements(ctx, axis, options.inputs);
+        options.kernel_name = v.get("kernel", "concat_kernel");
         options.set_launch_params(
             v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
-        auto src = interpolate_string(concat_kernel,
-                                      {{"kernel", options.kernel_name},
-                                       {"params", enum_params(inputs.size(), "void * private_p")},
-                                       {"args", enum_params(inputs.size(), "private_p")},
-                                       {"transformers", make_transformer_args(vec)},
-                                       {"axis", v.at("axis").to<std::string>()}});
+        auto src = interpolate_string(
+            concat_kernel,
+            {{"kernel", options.kernel_name},
+             {"params", enum_params(inputs.size(), "void * private_p")},
+             {"args", enum_params(inputs.size(), "private_p")},
+             {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
+             {"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
+             {"post", v.get("post", std::string{"op::id{}"})},
+             {"transformers", make_transformer_args(vec)},
+             {"preamble", v.get("preamble", std::string{})},
+             {"axis", v.at("axis").to<std::string>()}});
         return compile_hip_code_object(src, options);
     }

     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
+        auto v = op.to_value();
+        if(not ins->module_inputs().empty())
+        {
+            auto* pm           = ins->module_inputs().front();
+            v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
+            v["preamble"]      = generate_pointwise(*pm, "post_concat");
+            v["post"]          = "MIGRAPHX_LIFT(post_concat)";
+            v["kernel"]        = "concat_" + generate_name_from_ops(*pm) + "_kernel";
+        }
+        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
 };
...
...
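With this change the concat compiler can carry a fused pointwise module: when the instruction has a module input, compile() records a "concat_inputs" count, a generated "preamble", a "post" callable (MIGRAPHX_LIFT(post_concat)) and a derived kernel name in the value, and compile_op() substitutes those into the ${preamble}, ${post}, ${concat_params} and ${concat_args} placeholders of concat_kernel. A minimal sketch of that kind of ${name} substitution, using a hypothetical substitute() helper rather than MIGraphX's interpolate_string:

// Sketch only: replaces every ${key} in a template string with its value.
#include <cstddef>
#include <map>
#include <string>

std::string substitute(std::string text, const std::map<std::string, std::string>& vars)
{
    for(const auto& [key, value] : vars)
    {
        const std::string token = "${" + key + "}";
        for(std::size_t pos = text.find(token); pos != std::string::npos;
            pos = text.find(token, pos + value.size()))
        {
            text.replace(pos, token.size(), value);
        }
    }
    return text;
}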
src/targets/gpu/jit/pointwise.cpp
...
...
@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
-    std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
+    std::vector<std::string> names() const { return {"pointwise", "contiguous", "layout"}; }

     static std::size_t oversubscribe_if(bool b)
     {
...
...
@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        if(op.name() == "contiguous")
+        if(contains({"layout", "contiguous"}, op.name()))
         {
             return replace(compile_op(ctx,
                                       to_shapes(ins->inputs()),
                                       {{"lambda", "[](auto x) { return x; }"},
-                                       {"kernel", "contiguous_kernel"}}));
+                                       {"kernel", op.name() + "_kernel"}}));
         }
         else
         {
...
...
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
...
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
         return Start{} * output_shape.strides[Axis];
     });
     constexpr auto s = make_shape(lens, strides);
-    return make_tensor_view(&out[offset], s);
+    MIGRAPHX_ASSERT(offset < out.get_shape().element_space());
+    MIGRAPHX_ASSERT((s.element_space() + offset) <= out.get_shape().element_space());
+    return make_tensor_view(out.data() + offset, s);
 }

+template <index_int Axis, class Input, class Start, class... Ts>
+constexpr auto concat_slices(Input input, Start start, Ts... xs)
+{
+    return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
+}

 template <index_int Axis, class Input>
...
...
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
     return _c<lens[Axis]>;
 }

-template <index_int Axis, class Output, class... Inputs>
-__device__ void concat(Output output, Inputs... inputs)
+template <index_int Axis, class... Inputs>
+__device__ auto concat(Inputs... inputs)
 {
-    auto idx = make_index();
-    fold([&](auto start, auto input) {
-        auto y = concat_slice<Axis>(output, input, start);
-        idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
-        return start + concat_ends<Axis>(input);
-    })(_c<0>, inputs...);
+    return [=](auto f, auto... ts) {
+        auto idx = make_index();
+        fold([&](auto start, auto input) {
+            concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
+                idx.global_stride(input.get_shape().elements(),
+                                  [&](auto i) { y[i] = f(input[i], xs[i]...); });
+            });
+            return start + concat_ends<Axis>(input);
+        })(_c<0>, inputs...);
+    };
 }

 } // namespace migraphx
...
...
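On the kernel side, concat now returns a callable instead of writing the output directly: the fold threads a running start offset along Axis, concat_slices builds a view of each extra tensor at that offset, and the supplied f (the fused post-op) is applied as every element is copied, so y[i] = f(input[i], xs[i]...). A host-side, 1-D sketch of the same running-offset idea, for illustration only and not the device code:

// Sketch only: concatenate 1-D inputs while applying a post-op per element,
// advancing a running start offset much like the fold in concat.hpp.
#include <cstddef>
#include <vector>

std::vector<float> concat_with_post(const std::vector<std::vector<float>>& inputs,
                                    float (*post)(float))
{
    std::size_t total = 0;
    for(const auto& in : inputs)
        total += in.size();

    std::vector<float> out(total);
    std::size_t start = 0; // running offset along the concatenation axis
    for(const auto& in : inputs)
    {
        for(std::size_t i = 0; i < in.size(); ++i)
            out[start + i] = post(in[i]); // post-op applied as elements are written
        start += in.size();               // advance by this input's extent
    }
    return out;
}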
src/targets/gpu/lowering.cpp
...
...
@@ -29,19 +29,14 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
...
...
@@ -109,9 +104,9 @@ struct miopen_apply
         add_extend_op("scatter_none");
         add_extend_op("topk");
-        add_convolution_op<op::convolution>("convolution");
-        add_convolution_op<op::deconvolution>("deconvolution");
-        add_convolution_op<op::quant_convolution>("quant_convolution");
+        add_convolution_op("convolution");
+        add_convolution_op("deconvolution");
+        add_convolution_op("quant_convolution");
         add_gemm_op<op::dot>("dot");
         add_gemm_op<op::quant_dot>("quant_dot");
         add_if_op();
...
...
@@ -238,34 +233,19 @@ struct miopen_apply
         });
     }

-    template <typename Op>
     void add_convolution_op(const std::string& name)
     {
         apply_map.emplace(name, [=](instruction_ref ins) {
-            operation conv = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), int8_x4_format};
-            migraphx::context ctx = get_context();
-            size_t ws_bytes       = 0;
-            auto compile_conv_with_format = [&](bool format) {
-                conv     = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), format};
-                auto ws  = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
-                ws_bytes = ws.get("workspace", 0);
-            };
-            try
-            {
-                // for the regular convolution and deconvolution, this try would always succeed
-                compile_conv_with_format(int8_x4_format);
-            }
-            catch(migraphx::exception&)
-            {
-                // In case no solver supports the default format, retry using the other format.
-                compile_conv_with_format(not int8_x4_format);
-            }
-            auto args      = ins->inputs();
-            auto output    = insert_allocation(ins, ins->get_shape());
-            auto workspace = insert_allocation(ins, shape{shape::int8_type, {ws_bytes}});
-            return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
+            operation conv = make_op("gpu::" + name,
+                                     {{"op", ins->get_operator().to_value()},
+                                      {"int8_x4_format", int8_x4_format}});
+            auto output    = insert_allocation(ins, ins->get_shape());
+            return mod->replace_instruction(ins,
+                                            make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
+                                            ins->inputs().at(0),
+                                            ins->inputs().at(1),
+                                            output);
         });
     }
...
...
src/targets/gpu/mlir.cpp
...
...
@@ -101,7 +101,10 @@ struct mlir_handle
     mlir_handle(T p) : handle(ptr{p}) {}

-    T get() const { return handle.get().get(); }
+    T get() const
+    {
+        return handle.get().get(); // NOLINT(readability-redundant-smartptr-get)
+    }

     T release() { return handle.release().get(); }
...
...
src/targets/gpu/target.cpp
...
...
@@ -35,6 +35,7 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/preallocate_param.hpp>
...
...
@@ -50,6 +51,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
...
...
@@ -70,6 +72,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)

struct id_pass
{
...
...
@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        simplify_algebra{},
        simplify_reshapes{},
        enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
        dead_code_elimination{},
        simplify_reshapes{},
        simplify_algebra{},
        prefuse_ops{},
        dead_code_elimination{},
...
...
@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        eliminate_concat{concat_gpu_optimization{}},
        dead_code_elimination{},
        compile_miopen{&gctx},
        dead_code_elimination{},
        pack_int8_args{},
        dead_code_elimination{},
        adjust_allocation{gpu_allocation_model{}},
        dead_code_elimination{},
        fuse_ops{&ctx, options.fast_math},
        dead_code_elimination{},
        replace_allocate{gpu_allocation_model{}, options.offload_copy},
...
...
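Two additions stand out in this pass list: layout_nhwc is gated behind an environment switch, so enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}) only takes effect when the MIGRAPHX_ENABLE_NHWC environment variable is set (for example MIGRAPHX_ENABLE_NHWC=1), while compile_miopen{&gctx} now runs after eliminate_concat, presumably to compile the gpu::miopen_op instructions that lowering emits.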
test/layout_nhwc.cpp
new file (mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
}

migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3})
{
    return migraphx::make_op("layout", {{"permutation", permutation}});
}

migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruction_ref ins)
{
    return m.add_instruction(layout({0, 2, 3, 1}), ins);
}

TEST_CASE(conv_relu)
{
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
        auto w = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
        auto conv = m1.add_instruction(
            migraphx::make_op("convolution",
                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
            x,
            w);
        m1.add_instruction(migraphx::make_op("relu"), conv);
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto x = add_layout_nhwc(
            m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
        auto w = add_layout_nhwc(
            m2,
            m2.add_literal(
                migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}})));
        auto conv = m2.add_instruction(
            migraphx::make_op("convolution",
                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
            x,
            w);
        auto conv_layout = m2.add_instruction(layout(), conv);
        m2.add_instruction(migraphx::make_op("relu"), conv_layout);
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(conv_add)
{
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
        auto w = m1.add_literal(
            migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
        auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
        auto conv = m1.add_instruction(
            migraphx::make_op("convolution",
                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
            x,
            w);
        auto b = m1.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
            y);
        m1.add_instruction(migraphx::make_op("add"), conv, b);
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto x = add_layout_nhwc(
            m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
        auto w = add_layout_nhwc(
            m2,
            m2.add_literal(
                migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}})));
        auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
        auto conv = m2.add_instruction(
            migraphx::make_op("convolution",
                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
            x,
            w);
        auto conv_layout = m2.add_instruction(layout(), conv);
        auto b           = m2.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
            y);
        m2.add_instruction(migraphx::make_op("add"), conv_layout, b);
    }
    EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
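Both tests expect layout_nhwc to wrap the convolution inputs in layout ops with permutation {0, 2, 3, 1} (NCHW to NHWC) and to restore the default {0, 1, 2, 3} layout on the convolution output before the following elementwise op. As an illustration of what such a permutation means for a packed shape, here is a small sketch of a permutation-to-strides mapping, assuming the permutation lists dimensions from slowest to fastest varying in memory:

// Sketch only: for the test's x with lens {1, 8, 16, 16} (NCHW), the permutation
// {0, 2, 3, 1} (NHWC) gives packed strides {2048, 1, 128, 8}.
#include <cstddef>
#include <vector>

std::vector<std::size_t> strides_for_permutation(const std::vector<std::size_t>& lens,
                                                 const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> strides(lens.size());
    std::size_t running = 1;
    for(auto it = perm.rbegin(); it != perm.rend(); ++it)
    {
        strides[*it] = running; // the last entry of perm gets stride 1
        running *= lens[*it];
    }
    return strides;
}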
test/onnx/gen_onnx.py
...
...
@@ -5687,6 +5687,92 @@ def split_test_default():
    return ([node], [x], [y1, y2])


@onnx_test
def split_test_no_attribute():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
    y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
    y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
    y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
    y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
    split = np.ones(4) * 75
    split_tensor = helper.make_tensor(name="split",
                                      data_type=TensorProto.INT64,
                                      dims=split.shape,
                                      vals=split.astype(np.int64))
    const_node = helper.make_node("Constant",
                                  inputs=[],
                                  outputs=['split'],
                                  value=split_tensor)
    node = onnx.helper.make_node(
        'Split',
        inputs=['x', 'split'],
        outputs=['y1', 'y2', 'y3', 'y4'],
    )

    return ([const_node, node], [x], [y1, y2, y3, y4])


@onnx_test
def split_test_no_attribute_invalid_split():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
    y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
    y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
    y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
    y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
    split = np.ones(4)
    split_tensor = helper.make_tensor(name="split",
                                      data_type=TensorProto.INT64,
                                      dims=split.shape,
                                      vals=split.astype(np.int64))
    const_node = helper.make_node("Constant",
                                  inputs=[],
                                  outputs=['split'],
                                  value=split_tensor)
    node = onnx.helper.make_node(
        'Split',
        inputs=['x', 'split'],
        outputs=['y1', 'y2', 'y3', 'y4'],
    )

    return ([const_node, node], [x], [y1, y2, y3, y4])


@onnx_test
def split_test_invalid_split():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
    y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
    y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
    y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])

    node = onnx.helper.make_node('Split',
                                 inputs=['x'],
                                 outputs=['y1', 'y2', 'y3'],
                                 axis=1,
                                 split=[1, 1, 1])

    return ([node], [x], [y1, y2, y3])


@onnx_test
def split_test_no_attribute_invalid_input_split():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
    y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
    y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
    y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])

    node = onnx.helper.make_node('Split',
                                 inputs=['x'],
                                 outputs=['y1', 'y2', 'y3'],
                                 axis=1,
                                 split=[])

    return ([node], [x], [y1, y2, y3])


@onnx_test
def sqrt_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
...
...
test/onnx/onnx_test.cpp
...
...
@@ -5537,6 +5537,31 @@ TEST_CASE(split_test)
    EXPECT(p == prog);
}

TEST_CASE(split_test_no_attribute)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape si{migraphx::shape::int64_type, {4}, {1}};
    std::vector<int> ind = {75, 75, 75, 75};

    auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {300, 15}});
    mm->add_literal(migraphx::literal(si, ind));
    auto r1 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {75}}}), input);
    auto r2 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {75}}, {"ends", {150}}}), input);
    auto r3 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {150}}, {"ends", {225}}}), input);
    auto r4 = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {225}}, {"ends", {300}}}), input);
    mm->add_return({r1, r2, r3, r4});

    auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx");
    EXPECT(p == prog);
}

TEST_CASE(split_test_default)
{
    migraphx::program p;
...
...
@@ -5552,6 +5577,23 @@ TEST_CASE(split_test_default)
    EXPECT(p == prog);
}

TEST_CASE(split_test_no_attribute_invalid_split)
{
    EXPECT(test::throws(
        [&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); }));
}

TEST_CASE(split_test_invalid_split)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); }));
}

TEST_CASE(split_test_no_attribute_invalid_input_split)
{
    EXPECT(test::throws(
        [&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); }));
}

TEST_CASE(sqrt_test)
{
    migraphx::program p;
...
...
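The expected program spells out how the parser lowers Split when the sizes come from a constant split input: the 300-row tensor becomes four slice ops along axis 0, each start being the running sum of the preceding split sizes. A small sketch of that start/end arithmetic:

// Sketch only: per-output (start, end) ranges for a Split size list; for
// {75, 75, 75, 75} this yields {0,75}, {75,150}, {150,225}, {225,300}.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> split_ranges(const std::vector<int64_t>& splits)
{
    std::vector<std::pair<int64_t, int64_t>> ranges;
    int64_t start = 0;
    for(auto len : splits)
    {
        ranges.push_back({start, start + len});
        start += len;
    }
    return ranges;
}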
test/onnx/split_test_invalid_split.onnx
new file (mode 100644): binary ONNX protobuf; contents not reproduced here.
test/onnx/split_test_no_attribute.onnx
new file (mode 100644): binary ONNX protobuf; contents not reproduced here.
test/onnx/split_test_no_attribute_invalid_input_split.onnx
new file (mode 100644): binary ONNX protobuf; contents not reproduced here.
test/onnx/split_test_no_attribute_invalid_split.onnx
new file (mode 100644): binary ONNX protobuf; contents not reproduced here.
test/verify/test_concat_broadcast_add.cpp
new file (mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_concat_broadcast_add : verify_program<test_concat_broadcast_add>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4}};
        migraphx::shape s1{migraphx::shape::float_type, {1, 6, 4}};
        migraphx::shape s2{migraphx::shape::float_type, {6, 1}};
        auto x      = mm->add_parameter("x", s0);
        auto y      = mm->add_parameter("y", s0);
        auto z      = mm->add_parameter("z", s0);
        auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
        auto b      = mm->add_literal(migraphx::generate_literal(s2, 15));
        auto bb     = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), b);
        mm->add_instruction(migraphx::make_op("add"), concat, bb);
        return p;
    }
};
test/verify/test_slice_concat_add.cpp
new file (mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice_concat_add : verify_program<test_slice_concat_add>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s0{migraphx::shape::float_type, {1, 24, 2, 2}};
        migraphx::shape s1{migraphx::shape::float_type, {1, 8, 2, 2}};
        auto x     = mm->add_parameter("x", s0);
        auto y     = mm->add_parameter("y", s1);
        auto z     = mm->add_parameter("z", s0);
        auto slice = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), x);
        auto concat =
            mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), slice, y, y);
        mm->add_instruction(migraphx::make_op("add"), concat, z);
        return p;
    }
};
tools/convert_onnx_version.py
new file (mode 100644)
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import argparse
import onnx
from onnx import version_converter


def parse_args():
    parser = argparse.ArgumentParser(
        description=
        'MIGraphX Onnx Model Convertion. Use to convert the opset of the input model to MIGraphX\'s'
    )
    req_args = parser.add_argument_group(title='required arguments')
    req_args.add_argument('--model',
                          type=str,
                          required=True,
                          help='path to onnx file')
    req_args.add_argument('--output',
                          type=str,
                          required=True,
                          help='path to output onnx file')
    req_args.add_argument('--opset',
                          type=int,
                          required=True,
                          help='The output opset')
    req_args.add_argument('--infer_shapes',
                          action='store_true',
                          help='Infer shapes for output model')
    parser.add_argument('--verbose',
                        action='store_true',
                        help='show verbose information (for debugging)')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    model_path = args.model
    out_model_path = args.output
    target_opset = args.opset
    verbose = args.verbose
    infer_shapes = args.infer_shapes

    original_model = onnx.load(model_path)
    if verbose:
        print(f"The model before conversion:\n{original_model}")

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    # Apply the version conversion on the original model
    converted_model = version_converter.convert_version(original_model, target_opset)
    if infer_shapes:
        converted_model = onnx.shape_inference.infer_shapes(converted_model)
    if verbose:
        print(f"The model after conversion:\n{converted_model}")

    # Save the ONNX model
    onnx.save(converted_model, out_model_path)


if __name__ == '__main__':
    main()
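Typical usage, assuming the onnx Python package is installed, would be along the lines of python tools/convert_onnx_version.py --model input.onnx --output converted.onnx --opset 13, optionally with --infer_shapes to run shape inference on the converted model and --verbose to print both models; the file names and target opset here are only placeholders.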