Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13d14c66
Commit
13d14c66
authored
Oct 24, 2023
by
Brian Pickrell
Browse files
Merge branch 'develop' into dyn_resize_gather
parents
f4e7d9d9
d1abf06f
Changes
420
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
957 additions
and
198 deletions
+957
-198
test/ref_dev_examples.cpp
test/ref_dev_examples.cpp
+19
-20
test/replace_allocate.cpp
test/replace_allocate.cpp
+2
-2
test/rewrite_pooling_test.cpp
test/rewrite_pooling_test.cpp
+18
-20
test/rewrite_quantization_test.cpp
test/rewrite_quantization_test.cpp
+8
-1
test/run_on_target_test.cpp
test/run_on_target_test.cpp
+1
-1
test/shape_test.cpp
test/shape_test.cpp
+7
-7
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+213
-4
test/simplify_dyn_ops_test.cpp
test/simplify_dyn_ops_test.cpp
+240
-0
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+17
-17
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+101
-1
test/split_single_dyn_dim_test.cpp
test/split_single_dyn_dim_test.cpp
+4
-64
test/targets.cpp
test/targets.cpp
+1
-1
test/verify/CMakeLists.txt
test/verify/CMakeLists.txt
+3
-3
test/verify/ck_gemm_softmax_gemm.cpp
test/verify/ck_gemm_softmax_gemm.cpp
+56
-0
test/verify/run_verify.cpp
test/verify/run_verify.cpp
+3
-3
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+100
-52
test/verify/test_flatten_dot_relu.cpp
test/verify/test_flatten_dot_relu.cpp
+46
-0
test/verify/test_layernorm.cpp
test/verify/test_layernorm.cpp
+22
-2
test/verify/test_reduce_add.cpp
test/verify/test_reduce_add.cpp
+48
-0
test/verify/test_reduce_noop_add.cpp
test/verify/test_reduce_noop_add.cpp
+48
-0
No files found.
test/ref_dev_examples.cpp
View file @
13d14c66
...
...
@@ -140,24 +140,6 @@ TEST_CASE(handling_tensors)
-
0.06269585
,
0.18658121
,
-
0.03944227
,
0.0111798
,
-
0.17731084
,
0.11789055
,
-
0.09982193
,
0.08142821
,
0.0729029
,
0.11303909
,
0.12735154
,
0.03885292
};
// Solution vector
std
::
vector
<
float
>
sol
=
{
-
0.20817225
,
0.87965256
,
0.14958936
,
-
1.24887264
,
-
0.06540672
,
0.20778663
,
0.40456355
,
-
0.99900877
,
0.4917807
,
0.1994698
,
0.64205718
,
0.37798831
,
-
0.25315839
,
0.44276932
,
-
0.16138598
,
0.79344082
};
// Create the arguments in a parameter_map
migraphx
::
parameter_map
params
;
params
[
"X"
]
=
migraphx
::
argument
(
input_shape
,
a
.
data
());
...
...
@@ -167,8 +149,25 @@ TEST_CASE(handling_tensors)
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
(
64
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
sol
));
// Solution vector
std
::
vector
<
float
>
gold
=
{
-
0.20817225
,
0.87965256
,
0.14958936
,
-
1.24887264
,
-
0.06540672
,
0.20778663
,
0.40456355
,
-
0.99900877
,
0.4917807
,
0.1994698
,
0.64205718
,
0.37798831
,
-
0.25315839
,
0.44276932
,
-
0.16138598
,
0.79344082
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/replace_allocate.cpp
View file @
13d14c66
...
...
@@ -54,7 +54,7 @@ struct allocate_no_out : migraphx::auto_register_op<allocate_no_out>
const
migraphx
::
shape
&
output_shape
,
const
std
::
vector
<
migraphx
::
argument
>&
)
const
{
return
{
output_shape
};
return
migraphx
::
argument
{
output_shape
};
}
};
...
...
@@ -78,7 +78,7 @@ struct allocate_with_out : migraphx::auto_register_op<allocate_with_out>
const
migraphx
::
shape
&
output_shape
,
const
std
::
vector
<
migraphx
::
argument
>&
)
const
{
return
{
output_shape
};
return
migraphx
::
argument
{
output_shape
};
}
};
...
...
test/rewrite_pooling_test.cpp
View file @
13d14c66
...
...
@@ -50,10 +50,10 @@ TEST_CASE(rewrite_pooling_test)
migraphx
::
module
m
;
auto
input
=
m
.
add_parameter
(
"x"
,
s
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pooling"
,
{{
"mode"
,
mode
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
4
,
5
}}}),
{{
"mode"
,
mode
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
4
,
5
}}}),
input
);
m
.
add_return
({
ret
});
return
m
;
...
...
@@ -62,11 +62,8 @@ TEST_CASE(rewrite_pooling_test)
auto
opt_program
=
[
&
](
const
migraphx
::
operation
&
reduce_op
)
{
migraphx
::
module
m
;
auto
input
=
m
.
add_parameter
(
"x"
,
s
);
auto
rsp
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
-
1
}}}),
input
);
auto
rdm
=
m
.
add_instruction
(
reduce_op
,
rsp
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
1
,
1
,
1
}}}),
rdm
);
m
.
add_return
({
ret
});
auto
rdm
=
m
.
add_instruction
(
reduce_op
,
input
);
m
.
add_return
({
rdm
});
return
m
;
};
...
...
@@ -78,8 +75,9 @@ TEST_CASE(rewrite_pooling_test)
};
test_rewrite
(
migraphx
::
op
::
pooling_mode
::
average
,
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}}}));
test_rewrite
(
migraphx
::
op
::
pooling_mode
::
max
,
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}));
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
,
3
,
4
}}}));
test_rewrite
(
migraphx
::
op
::
pooling_mode
::
max
,
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
2
,
3
,
4
}}}));
}
TEST_CASE
(
rewrite_avepooling_na1_test
)
...
...
@@ -140,10 +138,10 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto
input
=
m
.
add_parameter
(
"x"
,
s
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pooling"
,
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
max
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
3
,
5
}}}),
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
max
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
3
,
5
}}}),
input
);
m
.
add_return
({
ret
});
return
m
;
...
...
@@ -168,10 +166,10 @@ TEST_CASE(literal_rewrite_pooling_test)
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_literal
(
migraphx
::
literal
(
s
,
data
));
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"pooling"
,
{{
"mode"
,
mode
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
4
,
5
}}}),
{{
"mode"
,
mode
},
{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
,
1
}},
{
"lengths"
,
{
3
,
4
,
5
}}}),
input
);
mm
->
add_return
({
ret
});
return
p
;
...
...
@@ -199,7 +197,7 @@ TEST_CASE(literal_rewrite_pooling_test)
auto
result1
=
p1
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
visit_all
(
result1
,
result2
)(
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify
::
verify_range
(
r1
,
r2
));
});
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
r1
,
r2
));
});
};
test_rewrite_pooling
(
migraphx
::
op
::
pooling_mode
::
max
,
...
...
test/rewrite_quantization_test.cpp
View file @
13d14c66
...
...
@@ -31,10 +31,13 @@
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
);
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_clip_scalar
(
migraphx
::
instruction
&
ins
)
...
...
@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
// ensure clip literals created in quantized program are scalar
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
// unless CK workarounds are enabled
if
(
migraphx
::
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
else
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
}
TEST_CASE
(
dequantizelinear
)
...
...
test/run_on_target_test.cpp
View file @
13d14c66
...
...
@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/shape_test.cpp
View file @
13d14c66
...
...
@@ -956,13 +956,13 @@ TEST_CASE(test_with_type)
TEST_CASE
(
test_multi_index
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
}
TEST_CASE
(
find_permutation_2d_standard
)
...
...
test/simplify_algebra_test.cpp
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast_no_common_axis
)
{
auto
b
=
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}});
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
5
,
10
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
5
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
{
migraphx
::
module
m
;
...
...
@@ -2910,6 +2927,179 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_multi_rsp
)
{
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
128
,
3
,
32
,
80
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
input
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
t1
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
t1
);
auto
slc2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
t1
);
auto
c1_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
32
,
128
,
80
}}}),
c1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
32
,
128
,
80
}}}),
c2
);
auto
r1_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c1_1
);
auto
r2_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c2_1
);
auto
c0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c0
);
auto
t2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
r1_1
);
auto
c_t2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
r0
,
c_t2
);
m1
.
add_return
({
r1
,
r2
,
dot
,
r2_1
});
};
migraphx
::
module
m2
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
128
,
3
,
32
,
80
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
input
);
auto
c_t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
384
,
128
,
80
}}}),
c_t1
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
256
}},
{
"ends"
,
{
384
}}}),
rsp1
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
128
}},
{
"ends"
,
{
256
}}}),
rsp1
);
auto
t_slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
slc1
);
auto
c_t_slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t_slc1
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
128
}}}),
rsp1
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
slc2
,
c_t_slc1
);
auto
c_t1_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
,
32
,
128
,
80
}}}),
c_t1_1
);
auto
slc2_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
rsp2
);
auto
slc2_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
8
}},
{
"ends"
,
{
12
}}}),
rsp2
);
m2
.
add_return
({
slc2_1
,
slc2_2
,
dot
,
slc0
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_partial
)
{
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
input
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
8
}},
{
"ends"
,
{
16
}}}),
input
);
auto
slc2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
16
}},
{
"ends"
,
{
24
}}}),
input
);
auto
slc3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
128
}}}),
input
);
auto
c0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
2
,
4
,
96
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m1
.
add_return
({
ret
,
slc3
});
};
migraphx
::
module
m2
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
32
,
4
,
96
}}}),
input
);
auto
slc3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
128
}}}),
input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
rsp
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
4
}}}),
rsp
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
6
}}}),
rsp
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m2
.
add_return
({
ret
,
slc3
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_uneven_slice
)
{
auto
create_p
=
[]
{
migraphx
::
module
m
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
31
}}}),
input
);
auto
slc1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
31
}},
{
"ends"
,
{
62
}}}),
input
);
auto
slc2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
62
}},
{
"ends"
,
{
93
}}}),
input
);
auto
slc3
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
93
}},
{
"ends"
,
{
128
}}}),
input
);
auto
c0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
c1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
1
,
31
,
96
};
auto
r0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
auto
sum
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m
.
add_return
({
ret
,
slc3
});
return
m
;
};
auto
m1
=
create_p
();
auto
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
template
<
std
::
size_t
BS
>
void
reorder_reshape_slice_diff_dims
()
{
...
...
@@ -2931,13 +3121,32 @@ void reorder_reshape_slice_diff_dims()
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
32
};
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
1
}}),
c2
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
1
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
m1
.
add_return
({
r0
,
r1
,
r2
});
};
auto
m2
=
m1
;
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
96
,
96
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
32
}},
{
"ends"
,
{
64
}}}),
input
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
auto
r1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
c1
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
96
};
auto
r_new
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
32
}}}),
r_new
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
96
}}}),
r_new
);
m2
.
add_return
({
slc0
,
r1
,
slc2
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
...
...
test/simplify_dyn_ops_test.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
simplify_dyn_ops
{},
migraphx
::
dead_code_elimination
{}});
}
TEST_CASE
(
static_broadcast
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
m0
.
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
s
.
lens
()}}),
literal_ins
);
auto
add_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
input
,
broadcast_lit
);
m0
.
add_return
({
add_ins
});
}
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
m1
.
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
literal_ins
,
input
);
auto
add_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
input
,
broadcast_lit
);
m1
.
add_return
({
add_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
static_multibroadcast
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}}};
auto
literal_ins
=
m0
.
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
literal_ins
);
auto
add_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
input
,
broadcast_lit
);
m0
.
add_return
({
add_ins
});
}
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}}};
auto
literal_ins
=
m1
.
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
input
);
auto
add_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
input
,
broadcast_lit
);
m1
.
add_return
({
add_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
after_split_dyn_broadcast_match
)
{
migraphx
::
program
p0
;
{
auto
*
mm0
=
p0
.
get_main_module
();
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p0
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
4
}};
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
sm_shape
.
lens
()}}),
literal_ins
);
auto
add_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
submod
->
add_return
({
add_ins
});
return
submod
;
};
auto
*
dim1
=
create_submodule
(
1
,
"dim_1"
);
auto
*
dim2
=
create_submodule
(
2
,
"dim_2"
);
auto
*
dim3
=
create_submodule
(
3
,
"dim_3"
);
auto
*
dim4
=
create_submodule
(
4
,
"dim_4"
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input0
=
mm0
->
add_parameter
(
"data"
,
s
);
std
::
vector
<
migraphx
::
shape
>
sub_shapes
=
{};
sub_shapes
.
push_back
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}});
migraphx
::
shape
out_attr
=
migraphx
::
shape
{
sub_shapes
};
auto
sm_ins
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"select_module"
,
{{
"output_dyn_shapes"
,
migraphx
::
to_value
(
out_attr
)}}),
{
input0
},
{
dim1
,
dim2
,
dim3
,
dim4
});
auto
ret
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
sm_ins
);
mm0
->
add_return
({
ret
});
}
migraphx
::
program
p1
;
{
auto
*
mm1
=
p1
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input1
=
mm1
->
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
mm1
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
literal_ins
,
input1
);
auto
add_ins
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
input1
,
broadcast_lit
);
mm1
->
add_return
({
add_ins
});
}
migraphx
::
run_passes
(
p1
,
{
migraphx
::
split_single_dyn_dim
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_dyn_ops
{}});
EXPECT
(
p0
==
p1
);
}
TEST_CASE
(
const_slice_3input
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m0
.
add_return
({
slice_ins
});
}
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_ends
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}}}),
input
,
input_starts
,
input_ends
);
m1
.
add_return
({
slice_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_3input_dyn
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
6
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m0
.
add_return
({
slice_ins
});
}
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
6
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_ends
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}}}),
input
,
input_starts
,
input_ends
);
m1
.
add_return
({
slice_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_4input
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m0
.
add_return
({
slice_ins
});
}
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_ends
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
input_axes
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
),
input
,
input_starts
,
input_ends
,
input_axes
);
m1
.
add_return
({
slice_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/simplify_qdq_test.cpp
View file @
13d14c66
...
...
@@ -479,11 +479,11 @@ TEST_CASE(conv_pooling_dot)
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
d5
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
d5
,
d1
);
auto
bc1
=
m1
.
add_instruction
(
...
...
@@ -526,11 +526,11 @@ TEST_CASE(conv_pooling_dot)
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
...
...
@@ -585,11 +585,11 @@ TEST_CASE(mobilenet_snippet)
auto
q1
=
add_quantize_op
(
mm
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
d5
=
add_quantize_op
(
mm
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
c1
=
mm
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
d5
,
d1
);
auto
bc1
=
mm
.
add_instruction
(
...
...
@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto
result2
=
p2
.
eval
({{
"input"
,
input
},
{
"weights"
,
weights
}}).
back
();
std
::
vector
<
float
>
rv2
(
16
);
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
rv1
,
rv2
));
}
TEST_CASE
(
dot_correctness
)
...
...
@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto
result2
=
p2
.
eval
({{
"a"
,
a
},
{
"b"
,
b
}}).
back
();
std
::
vector
<
float
>
rv2
(
sh3
.
elements
());
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
rv1
,
rv2
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/simplify_reshapes_test.cpp
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -67,6 +67,106 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return
m
;
}
TEST_CASE
(
broadcast_transpose
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
u1
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
,
2
,
3
}}}),
t1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_opt
)
{
// extra transpose from transformation will be optimized out
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
,
2
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
5
}}}),
u1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar_multi_use
)
{
// multibroadcast used more than once
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
auto
id
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb
);
m1
.
add_return
({
t1
,
id
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
id
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb2
);
m2
.
add_return
({
mb
,
id
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
double_contig
)
{
migraphx
::
program
p
;
...
...
test/split_single_dyn_dim_test.cpp
View file @
13d14c66
...
...
@@ -50,8 +50,8 @@ TEST_CASE(dynamic_batch)
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sm_shape
.
lens
()}}
),
literal_ins
);
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
sm_input
);
auto
add_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
submod
->
add_return
({
add_ins
});
...
...
@@ -107,8 +107,8 @@ TEST_CASE(multiple_outputs)
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sm_shape
.
lens
()}}
),
literal_ins
);
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
sm_input
);
auto
add0_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
auto
add1_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
sm_input
);
...
...
@@ -157,64 +157,4 @@ TEST_CASE(multiple_outputs)
EXPECT
(
p0
==
p1
);
}
TEST_CASE
(
broadcast_match
)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx
::
program
p0
;
{
auto
*
mm0
=
p0
.
get_main_module
();
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p0
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
4
}};
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
sm_shape
.
lens
()}}),
literal_ins
);
auto
add_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
submod
->
add_return
({
add_ins
});
return
submod
;
};
auto
*
dim1
=
create_submodule
(
1
,
"dim_1"
);
auto
*
dim2
=
create_submodule
(
2
,
"dim_2"
);
auto
*
dim3
=
create_submodule
(
3
,
"dim_3"
);
auto
*
dim4
=
create_submodule
(
4
,
"dim_4"
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input0
=
mm0
->
add_parameter
(
"data"
,
s
);
std
::
vector
<
migraphx
::
shape
>
sub_shapes
=
{};
sub_shapes
.
push_back
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}});
migraphx
::
shape
out_attr
=
migraphx
::
shape
{
sub_shapes
};
auto
sm_ins
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"select_module"
,
{{
"output_dyn_shapes"
,
migraphx
::
to_value
(
out_attr
)}}),
{
input0
},
{
dim1
,
dim2
,
dim3
,
dim4
});
auto
ret
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
sm_ins
);
mm0
->
add_return
({
ret
});
}
migraphx
::
program
p1
;
{
auto
*
mm1
=
p1
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input1
=
mm1
->
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
mm1
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
literal_ins
,
input1
);
auto
add_ins
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
input1
,
broadcast_lit
);
mm1
->
add_return
({
add_ins
});
}
run_pass
(
p1
);
EXPECT
(
p0
==
p1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/targets.cpp
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
test/verify/CMakeLists.txt
View file @
13d14c66
...
...
@@ -25,14 +25,14 @@
file
(
GLOB VERIFY_TESTS CONFIGURE_DEPENDS *.cpp
)
add_executable
(
test_verify
${
VERIFY_TESTS
}
)
add_dependencies
(
test
s
test_verify
)
add_dependencies
(
check
test_verify
)
rocm_mark_as_
test
(
test_verify
)
rocm_install_test
(
TARGETS
test_verify
)
target_link_libraries
(
test_verify migraphx migraphx_all_targets
)
target_include_directories
(
test_verify PUBLIC ../include
)
rocm_clang_tidy_check
(
test_verify
)
foreach
(
SECTION general rnn
)
add_test
_command
(
test_verify_
${
SECTION
}
test_verify
${
SECTION
}
)
rocm_
add_test
(
NAME
test_verify_
${
SECTION
}
COMMAND
test_verify
${
SECTION
}
)
set_tests_properties
(
test_verify_
${
SECTION
}
PROPERTIES
COST 100
)
...
...
test/verify/ck_gemm_softmax_gemm.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
ck_gemm_softmax_gemm
:
verify_program
<
ck_gemm_softmax_gemm
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
auto
m2_elements
=
m2_shape
.
elements
();
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
auto
zero
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
zeros
});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
gemm1
,
eight
);
auto
bias
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
scale
,
zero
);
auto
softmax
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
bias
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
softmax
,
b1
);
return
p
;
}
};
test/verify/run_verify.cpp
View file @
13d14c66
...
...
@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block
template
<
class
Function
>
std
::
future
<
typename
std
::
result_of
<
Function
()
>::
type
>
detach_async
(
Function
&&
f
,
bool
parallel
=
true
)
std
::
future
<
std
::
invoke_result_t
<
Function
>>
detach_async
(
Function
&&
f
,
bool
parallel
=
true
)
{
if
(
parallel
)
{
...
...
@@ -251,7 +250,8 @@ void run_verify::verify(const std::string& name,
std
::
size_t
num
=
gold
.
size
();
for
(
std
::
size_t
i
=
0
;
((
i
<
num
)
and
passed
);
++
i
)
{
passed
&=
migraphx
::
verify_args
(
tname
,
gold
[
i
],
result
[
i
]);
passed
&=
migraphx
::
verify_args_with_tolerance
(
tname
,
result
[
i
],
migraphx
::
verify
::
expected
{
gold
[
i
]});
}
if
(
not
passed
or
migraphx
::
enabled
(
MIGRAPHX_TRACE_TEST
{}))
...
...
test/verify/test_arg_ops.cpp
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
template
<
class
T
,
int
Axis
,
bool
LastIndex
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
LastIndex
,
NonStdShape
>>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break
;
default:
break
;
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
,
LastIndex
},
param
);
return
p
;
}
};
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
0
>;
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
0
>;
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
1
>;
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
1
>;
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
2
>;
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
2
>;
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
3
>;
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
3
>;
test/verify/test_flatten_dot_relu.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_flatten_dot_relu
:
verify_program
<
test_flatten_dot_relu
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
3
,
5
}});
a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
3
}}),
a
);
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
3
,
3
,
1
}});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
3
}}),
b
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
dot
);
return
p
;
}
};
test/verify/test_layernorm.cpp
View file @
13d14c66
...
...
@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
pow
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pow"
),
sub
,
exponent_mbcast
);
auto
var
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
}}}),
pow
);
auto
epsilon_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
dims
.
at
(
1
),
1
}}}),
epsilon
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
dims
.
at
(
0
),
dims
.
at
(
1
),
1
}}}),
epsilon
);
auto
add_epsilon
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
var
,
epsilon_mbcast
);
auto
sqrt
=
m
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
add_epsilon
);
auto
sqrt_mbcast
=
...
...
@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
div
=
m
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
sub
,
sqrt_mbcast
);
auto
scale_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
scale
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mbcast
,
div
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
div
,
scale_mbcast
);
auto
bias_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
bias
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mul
,
bias_mbcast
);
...
...
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
return
p
;
}
};
struct
test_add_layernorm_add_gemm_nonstd
:
verify_program
<
test_add_layernorm_add_gemm_nonstd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
s
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
8
,
1
,
16
},
{
1
,
2
,
0
});
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
16
,
64
}});
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
layernorm_ins
=
add_layernorm
(
*
mm
,
add
,
s
.
lens
());
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
layernorm_ins
,
z
);
return
p
;
}
};
test/verify/test_reduce_add.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct
test_reduce_add
:
verify_program
<
test_reduce_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
1000
,
2
,
2
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
half_type
,
{
1
,
32
,
128
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
reduce_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
reduce_max
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reduce_mean
,
reduce_max
);
mm
->
add_return
({
add
});
return
p
;
};
};
test/verify/test_reduce_noop_add.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct
test_reduce_noop_add
:
verify_program
<
test_reduce_noop_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
1000
,
1
,
1
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
half_type
,
{
1
,
32
,
128
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
reduce_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
reduce_max
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reduce_mean
,
reduce_max
);
mm
->
add_return
({
add
});
return
p
;
};
};
Prev
1
…
16
17
18
19
20
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment