Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
97d4bb6c
Unverified
Commit
97d4bb6c
authored
Jul 25, 2023
by
Ted Themistokleous
Committed by
GitHub
Jul 25, 2023
Browse files
Merge branch 'develop' into add_parity_check_ci
parents
39b097c7
bdbc38bc
Changes
106
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1039 additions
and
567 deletions
+1039
-567
test/onnx/shape_end_oob_test.onnx
test/onnx/shape_end_oob_test.onnx
+0
-0
test/onnx/shape_start_oob_test.onnx
test/onnx/shape_start_oob_test.onnx
+0
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+68
-68
test/op_shape_test.cpp
test/op_shape_test.cpp
+122
-16
test/py/test_gpu.py
test/py/test_gpu.py
+2
-2
test/quantization.cpp
test/quantization.cpp
+10
-16
test/ref_dev_examples.cpp
test/ref_dev_examples.cpp
+1
-1
test/ref_dot_op_test.cpp
test/ref_dot_op_test.cpp
+41
-41
test/ref_ops_nonstd_shape_test.cpp
test/ref_ops_nonstd_shape_test.cpp
+3
-3
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+538
-293
test/ref_rnn_ops_test.cpp
test/ref_rnn_ops_test.cpp
+90
-90
test/rewrite_pooling_test.cpp
test/rewrite_pooling_test.cpp
+2
-2
test/run_on_target_test.cpp
test/run_on_target_test.cpp
+1
-1
test/shape_test.cpp
test/shape_test.cpp
+17
-8
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+2
-2
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+100
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+15
-18
test/verify/run_verify.cpp
test/verify/run_verify.cpp
+23
-2
test/verify/test_convolution_backwards.cpp
test/verify/test_convolution_backwards.cpp
+2
-2
test/verify/test_convolution_backwards_1d.cpp
test/verify/test_convolution_backwards_1d.cpp
+2
-2
No files found.
test/onnx/shape_end_oob_test.onnx
0 → 100644
View file @
97d4bb6c
File added
test/onnx/shape_start_oob_test.onnx
0 → 100644
View file @
97d4bb6c
File added
test/onnx/verify_onnx.cpp
View file @
97d4bb6c
This diff is collapsed.
Click to expand it.
test/op_shape_test.cpp
View file @
97d4bb6c
...
@@ -453,37 +453,143 @@ TEST_CASE(contiguous_shape_singleton_dim)
...
@@ -453,37 +453,143 @@ TEST_CASE(contiguous_shape_singleton_dim)
expect_shape
(
output
,
migraphx
::
make_op
(
"contiguous"
),
input
);
expect_shape
(
output
,
migraphx
::
make_op
(
"contiguous"
),
input
);
}
}
TEST_CASE
(
deconvolution_shape
)
TEST_CASE
(
convolution_backwards_1d
)
{
migraphx
::
shape
input_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
}};
migraphx
::
shape
weights_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
migraphx
::
shape
output_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
expect_shape
(
output_1d
,
migraphx
::
make_op
(
"convolution_backwards"
,
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
input_1d
,
weights_1d
);
}
TEST_CASE
(
convolution_backwards_2d
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"convolution_backwards"
),
input
,
weights
);
throws_shape
(
migraphx
::
make_op
(
"convolution_backwards"
),
input
);
throws_shape
(
migraphx
::
make_op
(
"convolution_backwards"
,
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
input
);
}
TEST_CASE
(
convolution_backwards_1padding
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"deconvolution"
),
input
,
weights
);
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"deconvolution"
),
input
);
expect_shape
(
output
,
throws_shape
(
migraphx
::
make_op
(
"convolution_backwards"
,
migraphx
::
make_op
(
"deconvolution"
,
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}}}),
input
);
input
,
weights
);
}
migraphx
::
shape
input_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
}};
TEST_CASE
(
convolution_backwards_2stride
)
migraphx
::
shape
output_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
{
migraphx
::
shape
weights_1d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
4
,
4
}};
expect_shape
(
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
output_1d
,
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
9
,
9
}};
migraphx
::
make_op
(
"deconvolution"
,
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
expect_shape
(
output
,
input_1d
,
migraphx
::
make_op
(
"convolution_backwards"
,
weights_1d
);
{{
"padding"
,
{
0
,
0
}},
{
"stride"
,
{
2
,
2
}},
{
"dilation"
,
{
1
,
1
}}}),
input
,
weights
);
}
TEST_CASE
(
convolution_backwards_2dilation
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
4
,
4
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
8
,
8
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"convolution_backwards"
,
{{
"padding"
,
{
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
2
,
2
}}}),
input
,
weights
);
}
TEST_CASE
(
convolution_backwards_3d
)
{
migraphx
::
shape
input_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
,
1
}};
migraphx
::
shape
input_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
,
1
}};
migraphx
::
shape
output_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
,
3
}};
migraphx
::
shape
output_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
,
3
}};
migraphx
::
shape
weights_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
,
3
}};
migraphx
::
shape
weights_3d
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
,
3
}};
expect_shape
(
expect_shape
(
output_3d
,
output_3d
,
migraphx
::
make_op
(
"
de
convolution"
,
migraphx
::
make_op
(
"convolution
_backwards
"
,
{{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"dilation"
,
{
1
,
1
,
1
}}}),
{{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"dilation"
,
{
1
,
1
,
1
}}}),
input_3d
,
input_3d
,
weights_3d
);
weights_3d
);
}
}
TEST_CASE
(
convolution_backwards_channel_mismatch
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
3
,
3
,
3
,
3
}};
throws_shape
(
migraphx
::
make_op
(
"convolution_backwards"
),
input
,
weights
);
}
TEST_CASE
(
convolution_backwards_dyn_batch_2d
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
3
,
3
},
{
3
,
3
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"convolution_backwards"
),
input
,
weights
);
}
TEST_CASE
(
convolution_backwards_dyn_img_2d
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
},
{
4
,
4
},
{
1
,
5
},
{
1
,
5
}}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
},
{
3
,
3
},
{
3
,
7
},
{
3
,
7
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"convolution_backwards"
),
input
,
weights
);
}
TEST_CASE
(
convolution_backwards_dyn_kernel_2d
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
1
,
1
}};
migraphx
::
shape
weights
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
},
{
3
,
3
},
{
2
,
6
},
{
2
,
6
}}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
},
{
3
,
3
},
{
2
,
6
},
{
2
,
6
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"convolution_backwards"
),
input
,
weights
);
}
TEST_CASE
(
dimensions_of0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
2
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
int64_type
,
{
4
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"dimensions_of"
,
{{
"end"
,
4
}}),
input
);
}
TEST_CASE
(
dimensions_of1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
2
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"dimensions_of"
,
{{
"start"
,
1
},
{
"end"
,
3
}}),
input
);
}
TEST_CASE
(
dimensions_of2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
2
}},
{
2
,
4
},
{
2
,
4
},
{
1
,
6
,
{
2
}}}};
migraphx
::
shape
output
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"dimensions_of"
,
{{
"start"
,
1
},
{
"end"
,
3
}}),
input
);
}
TEST_CASE
(
dimensions_of_error0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
2
}},
{
2
,
4
}}};
throws_shape
(
migraphx
::
make_op
(
"dimensions_of"
,
{{
"start"
,
3
},
{
"end"
,
3
}}),
input
);
}
TEST_CASE
(
dimensions_of_error1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
2
}},
{
2
,
4
}}};
throws_shape
(
migraphx
::
make_op
(
"dimensions_of"
,
{{
"start"
,
3
},
{
"end"
,
0
}}),
input
);
}
TEST_CASE
(
dot_ndim_error0
)
TEST_CASE
(
dot_ndim_error0
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
...
@@ -1134,7 +1240,7 @@ TEST_CASE(inconsistent_attr_shape)
...
@@ -1134,7 +1240,7 @@ TEST_CASE(inconsistent_attr_shape)
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
}},
{
"dilation"
,
{
3
,
3
,
3
}}}),
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
}},
{
"dilation"
,
{
3
,
3
,
3
}}}),
input
,
input
,
weights
);
weights
);
throws_shape
(
migraphx
::
make_op
(
"
de
convolution"
,
throws_shape
(
migraphx
::
make_op
(
"convolution
_backwards
"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
}},
{
"dilation"
,
{
3
,
3
,
3
}}}),
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
2
}},
{
"dilation"
,
{
3
,
3
,
3
}}}),
input
,
input
,
weights
);
weights
);
...
...
test/py/test_gpu.py
View file @
97d4bb6c
...
@@ -33,8 +33,8 @@ def test_conv_relu():
...
@@ -33,8 +33,8 @@ def test_conv_relu():
p
=
migraphx
.
parse_onnx
(
"conv_relu_maxpool_test.onnx"
)
p
=
migraphx
.
parse_onnx
(
"conv_relu_maxpool_test.onnx"
)
print
(
p
)
print
(
p
)
print
(
"Compiling ..."
)
print
(
"Compiling ..."
)
# set offload_copy, fast_match
and exhaustive_tune
to true
# set offload_copy, fast_match to true
p
.
compile
(
migraphx
.
get_target
(
"gpu"
),
True
,
True
,
True
)
p
.
compile
(
migraphx
.
get_target
(
"gpu"
),
True
,
True
)
print
(
p
)
print
(
p
)
params
=
{}
params
=
{}
...
...
test/quantization.cpp
View file @
97d4bb6c
...
@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph)
...
@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph)
auto
create_fp16_program
=
[]
{
auto
create_fp16_program
=
[]
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_type
,
{
1
}};
migraphx
::
shape
sd
{
migraphx
::
shape
::
half_type
,
{
1
}};
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
sd
,
{
1
}));
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
(
sd
,
{
2
}));
auto
l3
=
mm
->
add_literal
(
migraphx
::
literal
(
sd
,
{
3
}));
migraphx
::
shape
sx
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}};
migraphx
::
shape
sx
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}};
migraphx
::
shape
sy
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}};
migraphx
::
shape
sy
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}};
migraphx
::
shape
sc
{
migraphx
::
shape
::
bool_type
};
migraphx
::
shape
sc
{
migraphx
::
shape
::
bool_type
};
...
@@ -390,17 +387,15 @@ TEST_CASE(fp16_subgraph)
...
@@ -390,17 +387,15 @@ TEST_CASE(fp16_subgraph)
auto
x
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
x
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
y
=
mm
->
add_parameter
(
"y"
,
sy
);
auto
y
=
mm
->
add_parameter
(
"y"
,
sy
);
auto
*
then_mod
=
p
.
create_module
(
"If_6_if"
);
auto
*
then_mod
=
p
.
create_module
(
"If_6_if"
);
auto
hl
1
=
then_mod
->
add_
instruction
(
auto
hl
2
=
then_mod
->
add_
literal
(
migraphx
::
literal
(
sd
,
{
2
}));
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
l1
);
auto
hl1
=
then_mod
->
add_literal
(
migraphx
::
literal
(
sd
,
{
1
})
);
auto
mhl1
=
then_mod
->
add_instruction
(
auto
mhl1
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
4
}}}),
hl1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
4
}}}),
hl1
);
auto
hx
=
then_mod
->
add_instruction
(
auto
hx
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
x
);
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
x
);
auto
ad
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
hx
,
mhl1
);
auto
ad
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
hx
,
mhl1
);
auto
fad
=
then_mod
->
add_instruction
(
auto
fad
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
float_type
}}),
ad
);
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
float_type
}}),
ad
);
auto
hl2
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
l2
);
auto
mhl2
=
then_mod
->
add_instruction
(
auto
mhl2
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
4
}}}),
hl2
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
4
}}}),
hl2
);
auto
hy1
=
then_mod
->
add_instruction
(
auto
hy1
=
then_mod
->
add_instruction
(
...
@@ -411,9 +406,8 @@ TEST_CASE(fp16_subgraph)
...
@@ -411,9 +406,8 @@ TEST_CASE(fp16_subgraph)
then_mod
->
add_return
({
fad
,
fmu
,
mu
});
then_mod
->
add_return
({
fad
,
fmu
,
mu
});
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
hl3
=
else_mod
->
add_instruction
(
auto
hl3
=
else_mod
->
add_literal
(
migraphx
::
literal
(
sd
,
{
3
}));
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
l3
);
auto
mhl3
=
else_mod
->
add_instruction
(
auto
mhl3
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
4
}}}),
hl3
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
4
}}}),
hl3
);
auto
hx2
=
else_mod
->
add_instruction
(
auto
hx2
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
x
);
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
x
);
...
@@ -1020,7 +1014,7 @@ TEST_CASE(target_copy)
...
@@ -1020,7 +1014,7 @@ TEST_CASE(target_copy)
std
::
vector
<
float
>
orig_result
;
std
::
vector
<
float
>
orig_result
;
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
EXPECT
(
migraphx
::
verify_range
(
ref_result
,
orig_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
ref_result
,
orig_result
));
}
}
}
}
...
@@ -1084,7 +1078,7 @@ TEST_CASE(int8_quantization_dot)
...
@@ -1084,7 +1078,7 @@ TEST_CASE(int8_quantization_dot)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
EXPECT
(
migraphx
::
verify_range
(
quant_result
,
no_quant_result
,
30000
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
quant_result
,
no_quant_result
,
30000
));
}
}
}
}
...
@@ -1129,7 +1123,7 @@ TEST_CASE(int8_quantization_conv)
...
@@ -1129,7 +1123,7 @@ TEST_CASE(int8_quantization_conv)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
no_quant_result
);
EXPECT
(
migraphx
::
verify_range
(
quant_result
,
no_quant_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
quant_result
,
no_quant_result
));
}
}
}
}
...
@@ -1281,7 +1275,7 @@ TEST_CASE(test_op_capture)
...
@@ -1281,7 +1275,7 @@ TEST_CASE(test_op_capture)
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
vec
,
cap_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
vec
,
cap_vec
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/ref_dev_examples.cpp
View file @
97d4bb6c
...
@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
...
@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std
::
vector
<
float
>
results_vector
(
64
);
std
::
vector
<
float
>
results_vector
(
64
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
sol
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
sol
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/ref_dot_op_test.cpp
View file @
97d4bb6c
...
@@ -80,7 +80,7 @@ void dot_2d_test()
...
@@ -80,7 +80,7 @@ void dot_2d_test()
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
T
>
results_vector
;
std
::
vector
<
T
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE_REGISTER
(
dot_2d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
double
>
)
...
@@ -131,7 +131,7 @@ void dot_4d_test()
...
@@ -131,7 +131,7 @@ void dot_4d_test()
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
T
>
results_vector
;
std
::
vector
<
T
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE_REGISTER
(
dot_4d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
double
>
)
...
@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
...
@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487
,
0.40245487
,
1.80182751
};
1.80182751
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_3D_C_test0
)
TEST_CASE
(
dot_3D_C_test0
)
...
@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
...
@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487
,
0.40245487
,
1.80182751
};
1.80182751
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_3D_C_test1
)
TEST_CASE
(
dot_3D_C_test1
)
...
@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
...
@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-
0.95536130
,
-
0.95536130
,
2.27996211
};
2.27996211
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_test1
)
TEST_CASE
(
dot_4D_test1
)
...
@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
...
@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-
0.95467340
,
-
1.74728628
,
-
2.42477030
,
0.76262372
,
0.15539164
,
-
0.95467340
,
-
1.74728628
,
-
2.42477030
,
0.76262372
,
0.15539164
,
3.32281958
,
0.96769613
,
0.43727545
,
2.43019906
};
3.32281958
,
0.96769613
,
0.43727545
,
2.43019906
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_alpha_beta_test
)
TEST_CASE
(
dot_4D_alpha_beta_test
)
...
@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
...
@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
...
@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
...
@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_2D_C_test0
)
TEST_CASE
(
dot_2D_C_test0
)
...
@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
...
@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
...
@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
...
@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
...
@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
...
@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
...
@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
...
@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
...
@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
...
@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
...
@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
...
@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
...
@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
...
@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
...
@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
...
@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
...
@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
...
@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE
(
dot_dyn_4D_test
)
TEST_CASE
(
dot_dyn_4D_test
)
...
@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
...
@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE
(
quant_dot_2args_multi4
)
TEST_CASE
(
quant_dot_2args_multi4
)
...
@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
...
@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
...
@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
...
test/ref_ops_nonstd_shape_test.cpp
View file @
97d4bb6c
...
@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
...
@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
argmin_test_nonstd_shape
)
TEST_CASE
(
argmin_test_nonstd_shape
)
...
@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
...
@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
isnan_broadcast_test
)
TEST_CASE
(
isnan_broadcast_test
)
...
@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
...
@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std
::
vector
<
float
>
results_vector
;
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
correct
=
{
0
,
0
,
0
,
0
,
1
,
1
};
std
::
vector
<
float
>
correct
=
{
0
,
0
,
0
,
0
,
1
,
1
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
correct
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
correct
));
}
}
TEST_CASE
(
squeeze_transpose_test
)
TEST_CASE
(
squeeze_transpose_test
)
...
...
test/ref_ops_test.cpp
View file @
97d4bb6c
This diff is collapsed.
Click to expand it.
test/ref_rnn_ops_test.cpp
View file @
97d4bb6c
This diff is collapsed.
Click to expand it.
test/rewrite_pooling_test.cpp
View file @
97d4bb6c
...
@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
...
@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2
.
compile
(
migraphx
::
make_target
(
"ref"
));
p2
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result1
=
p1
.
eval
({}).
back
();
auto
result1
=
p1
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
visit_all
(
result1
,
visit_all
(
result1
,
result2
)(
result2
)(
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify
::
verify
_range
(
r1
,
r2
));
});
};
};
test_rewrite_pooling
(
migraphx
::
op
::
pooling_mode
::
max
,
test_rewrite_pooling
(
migraphx
::
op
::
pooling_mode
::
max
,
...
...
test/run_on_target_test.cpp
View file @
97d4bb6c
...
@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
...
@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std
::
vector
<
float
>
results_vector
(
3
);
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
gold
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/shape_test.cpp
View file @
97d4bb6c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -228,6 +228,15 @@ TEST_CASE(test_shape_dynamic_errors)
...
@@ -228,6 +228,15 @@ TEST_CASE(test_shape_dynamic_errors)
EXPECT
(
test
::
throws
([
&
]
{
s
.
index
(
std
::
vector
<
std
::
size_t
>
{
0
,
1
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
index
(
std
::
vector
<
std
::
size_t
>
{
0
,
1
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
({
3
,
5
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
({
3
,
5
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
(
shape
::
float_type
,
{
3
,
5
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
(
shape
::
float_type
,
{
3
,
5
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
lens
();
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
strides
();
}));
}
TEST_CASE
(
test_shape_static_dyn_dim_error
)
{
using
migraphx
::
shape
;
migraphx
::
shape
s
{
shape
::
float_type
,
{
2
,
3
,
4
}};
EXPECT
(
test
::
throws
([
&
]
{
s
.
dyn_dims
();
}));
}
}
TEST_CASE
(
test_shape_dynamic_serialize
)
TEST_CASE
(
test_shape_dynamic_serialize
)
...
@@ -947,13 +956,13 @@ TEST_CASE(test_with_type)
...
@@ -947,13 +956,13 @@ TEST_CASE(test_with_type)
TEST_CASE
(
test_multi_index
)
TEST_CASE
(
test_multi_index
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
}
}
TEST_CASE
(
find_permutation_2d_standard
)
TEST_CASE
(
find_permutation_2d_standard
)
...
...
test/simplify_qdq_test.cpp
View file @
97d4bb6c
...
@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
...
@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto
result2
=
p2
.
eval
({{
"input"
,
input
},
{
"weights"
,
weights
}}).
back
();
auto
result2
=
p2
.
eval
({{
"input"
,
input
},
{
"weights"
,
weights
}}).
back
();
std
::
vector
<
float
>
rv2
(
16
);
std
::
vector
<
float
>
rv2
(
16
);
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
rv1
,
rv2
));
}
}
TEST_CASE
(
dot_correctness
)
TEST_CASE
(
dot_correctness
)
...
@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
...
@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto
result2
=
p2
.
eval
({{
"a"
,
a
},
{
"b"
,
b
}}).
back
();
auto
result2
=
p2
.
eval
({{
"a"
,
a
},
{
"b"
,
b
}}).
back
();
std
::
vector
<
float
>
rv2
(
sh3
.
elements
());
std
::
vector
<
float
>
rv2
(
sh3
.
elements
());
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
rv1
,
rv2
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/simplify_reshapes_test.cpp
View file @
97d4bb6c
...
@@ -357,6 +357,106 @@ TEST_CASE(nop_convert)
...
@@ -357,6 +357,106 @@ TEST_CASE(nop_convert)
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
1
);
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
1
);
}
}
TEST_CASE
(
nested_reshape
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
rshp1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
3
,
4
,
5
,
42
}}}),
x
);
auto
rshp2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
12
,
5
,
42
}}}),
rshp1
);
auto
rshp3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
12
,
5
,
42
}}}),
rshp2
);
auto
rshp4
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
60
,
42
}}}),
rshp3
);
auto
rshp5
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
120
,
42
}}}),
rshp4
);
auto
rshp6
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
5040
}}}),
rshp5
);
m1
.
add_return
({
rshp6
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
rshp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
5040
}}}),
x
);
m2
.
add_return
({
rshp
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
nested_reshape_contiguous
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
rshp1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
3
,
4
,
5
,
42
}}}),
x
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rshp1
);
auto
rshp2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
12
,
5
,
42
}}}),
c1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rshp2
);
auto
rshp3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
12
,
5
,
42
}}}),
c2
);
auto
c3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rshp3
);
auto
rshp4
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
60
,
42
}}}),
c3
);
auto
c4
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rshp4
);
auto
rshp5
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
120
,
42
}}}),
c4
);
auto
c5
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rshp5
);
auto
rshp6
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
5040
}}}),
c5
);
m1
.
add_return
({
rshp6
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
rshp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
5040
}}}),
x
);
m2
.
add_return
({
rshp
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
nested_reshape_squeeze
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
rshp
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
12
}}}),
x
);
auto
squeeze
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
rshp
);
m1
.
add_return
({
squeeze
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
rshp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
12
}}}),
x
);
m2
.
add_return
({
rshp
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
nested_squeeze_reshape
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
squeeze
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
x
);
auto
rshp
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
12
}}}),
squeeze
);
m1
.
add_return
({
rshp
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
rshp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
12
}}}),
x
);
m2
.
add_return
({
rshp
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
concat_multibroadcasts1
)
TEST_CASE
(
concat_multibroadcasts1
)
{
{
// Broadcasted batch dim, new axis < old axis
// Broadcasted batch dim, new axis < old axis
...
...
test/tf/tf_test.cpp
View file @
97d4bb6c
...
@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test)
...
@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test)
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
rt
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
0.5
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
1e-4
f
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
1e-4
f
}});
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
...
@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test)
...
@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test)
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
numer
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
x_sub_mean
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
denom
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
pow
"
),
{
var_eps
,
rt
}
);
auto
rsqrt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
rsqrt
"
),
var_eps
);
auto
div0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
div
"
),
{
numer
,
denom
});
auto
mul0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
mul
"
),
{
usq_scale
,
rsqrt
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
div0
,
usq_scale
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
x_sub_mean
,
mul0
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
auto
prog
=
optimize_tf
(
"batchnorm_test.pb"
,
true
);
auto
prog
=
optimize_tf
(
"batchnorm_test.pb"
,
true
);
...
@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test)
...
@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test)
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
rt
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
half_type
,
{
0.5
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
half_type
,
{
1e-4
f
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
half_type
,
{
1e-4
f
}});
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
...
@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test)
...
@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test)
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
numer
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
x_sub_mean
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
denom
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
pow
"
),
{
var_eps
,
rt
}
);
auto
rsqrt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
rsqrt
"
),
var_eps
);
auto
div0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
div
"
),
{
numer
,
denom
});
auto
mul0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
mul
"
),
{
usq_scale
,
rsqrt
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
div0
,
usq_scale
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
x_sub_mean
,
mul0
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
auto
prog
=
optimize_tf
(
"batchnorm_half_test.pb"
,
true
);
auto
prog
=
optimize_tf
(
"batchnorm_half_test.pb"
,
true
);
...
@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test)
...
@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test)
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
std
::
vector
<
float
>
scale_data
(
32
,
1.0
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
scale
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
32
}},
scale_data
);
auto
rt
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
0.5
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
1e-6
f
}});
auto
eps
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
1e-6
f
}});
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
auto
usq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
scale
);
...
@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test)
...
@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test)
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
mean
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
usq_var
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
var
);
auto
numer
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
x_sub_mean
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
usq_mean
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
var_eps
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
usq_var
,
eps
});
auto
denom
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
pow
"
),
{
var_eps
,
rt
}
);
auto
rsqrt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
rsqrt
"
),
var_eps
);
auto
div0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
div
"
),
{
numer
,
denom
});
auto
mul0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"
mul
"
),
{
usq_scale
,
rsqrt
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
div0
,
usq_scale
});
auto
r0
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
x_sub_mean
,
mul0
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
r0
,
usq_bias
});
auto
prog
=
optimize_tf
(
"batchnormv3_test.pb"
,
true
);
auto
prog
=
optimize_tf
(
"batchnormv3_test.pb"
,
true
);
...
...
test/verify/run_verify.cpp
View file @
97d4bb6c
...
@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p,
...
@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p,
auto
num
=
shapes
.
size
();
auto
num
=
shapes
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
num
;
++
i
)
{
{
if
(
p
.
get_output_shapes
()[
i
].
lens
()
!=
shapes
[
i
].
lens
())
auto
output_shape
=
p
.
get_output_shapes
()[
i
];
if
(
output_shape
.
dynamic
()
and
shapes
[
i
].
dynamic
())
{
if
(
output_shape
.
dyn_dims
()
!=
shapes
[
i
].
dyn_dims
())
{
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its dynamic output dimensions"
);
}
}
else
if
(
not
(
output_shape
.
dynamic
()
or
shapes
[
i
].
dynamic
()))
{
if
(
output_shape
.
lens
()
!=
shapes
[
i
].
lens
())
{
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its static output dimensions"
);
}
}
else
{
{
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its shape"
);
throw
std
::
runtime_error
(
"Compiling program with "
+
name
+
" alters its output dimensions (static shape vs dynamic shape)"
);
}
}
}
}
if
(
t
.
name
()
!=
"ref"
)
if
(
t
.
name
()
!=
"ref"
)
...
...
test/verify/test_
de
conv.cpp
→
test/verify/test_conv
olution_backwards
.cpp
View file @
97d4bb6c
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_
de
conv
:
verify_program
<
test_
de
conv
>
struct
test_conv
olution_backwards
:
verify_program
<
test_conv
olution_backwards
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv>
...
@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv>
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
,
3
}});
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
,
3
}});
auto
weights
=
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
,
3
}});
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
,
3
}});
mm
->
add_instruction
(
migraphx
::
make_op
(
"
de
convolution"
),
input
,
weights
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution
_backwards
"
),
input
,
weights
);
return
p
;
return
p
;
}
}
};
};
test/verify/test_
de
conv_1d.cpp
→
test/verify/test_conv
olution_backwards
_1d.cpp
View file @
97d4bb6c
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_
de
conv_1d
:
verify_program
<
test_
de
conv_1d
>
struct
test_conv
olution_backwards
_1d
:
verify_program
<
test_conv
olution_backwards
_1d
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d>
...
@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d>
auto
weights
=
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
}});
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
3
}});
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
"
de
convolution"
,
migraphx
::
make_op
(
"convolution
_backwards
"
,
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
{{
"padding"
,
{
0
}},
{
"stride"
,
{
1
}},
{
"dilation"
,
{
1
}}}),
input
,
input
,
weights
);
weights
);
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment