Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9c5f6324
Commit
9c5f6324
authored
Jan 27, 2022
by
Shucai Xiao
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into keep_std_shape
parents
90f10299
332cb710
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
217 additions
and
32 deletions
+217
-32
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-1
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+1
-1
src/onnx/parse_hardsigmoid.cpp
src/onnx/parse_hardsigmoid.cpp
+20
-9
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+1
-1
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+2
-1
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+10
-0
test/onnx/hardswish_test.onnx
test/onnx/hardswish_test.onnx
+11
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+35
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+1
-0
test/ref_ops_nonstd_shape_test.cpp
test/ref_ops_nonstd_shape_test.cpp
+58
-0
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+76
-18
No files found.
src/include/migraphx/op/argmax.hpp
View file @
9c5f6324
...
@@ -35,7 +35,7 @@ struct argmax
...
@@ -35,7 +35,7 @@ struct argmax
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
...
...
src/include/migraphx/op/argmin.hpp
View file @
9c5f6324
...
@@ -35,7 +35,7 @@ struct argmin
...
@@ -35,7 +35,7 @@ struct argmin
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
...
...
src/onnx/parse_hardsigmoid.cpp
View file @
9c5f6324
...
@@ -10,20 +10,27 @@ namespace onnx {
...
@@ -10,20 +10,27 @@ namespace onnx {
struct
parse_hardsigmoid
:
op_parser
<
parse_hardsigmoid
>
struct
parse_hardsigmoid
:
op_parser
<
parse_hardsigmoid
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"HardSigmoid"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"HardSigmoid"
}
,
{
"HardSwish"
}
};
}
instruction_ref
parse
(
const
op_desc
&
/*
opd
*/
,
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
float
alpha
=
0.2
;
float
alpha
=
0.2
;
float
beta
=
0.5
;
float
beta
=
0.5
;
if
(
contains
(
info
.
attributes
,
"alpha"
))
if
(
opd
.
onnx_name
==
"HardSwish"
)
alpha
=
info
.
attributes
.
at
(
"alpha"
).
f
();
{
alpha
=
1.0
/
6.0
;
}
else
{
if
(
contains
(
info
.
attributes
,
"alpha"
))
alpha
=
info
.
attributes
.
at
(
"alpha"
).
f
();
if
(
contains
(
info
.
attributes
,
"beta"
))
if
(
contains
(
info
.
attributes
,
"beta"
))
beta
=
info
.
attributes
.
at
(
"beta"
).
f
();
beta
=
info
.
attributes
.
at
(
"beta"
).
f
();
}
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
input_type
=
args
[
0
]
->
get_shape
().
type
();
auto
input_type
=
args
[
0
]
->
get_shape
().
type
();
...
@@ -40,9 +47,13 @@ struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
...
@@ -40,9 +47,13 @@ struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
1
}}));
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
1
}}));
auto
mul
=
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb_alpha
,
args
[
0
]);
auto
mul
=
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb_alpha
,
args
[
0
]);
auto
add
=
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mb_beta
,
mul
);
auto
add
=
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mb_beta
,
mul
);
return
info
.
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
mb_zero
,
mb_one
);
auto
hardsigmoid
=
info
.
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
mb_zero
,
mb_one
);
if
(
opd
.
onnx_name
==
"HardSwish"
)
return
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
args
[
0
],
hardsigmoid
);
return
hardsigmoid
;
}
}
};
};
...
...
src/targets/gpu/argmax.cpp
View file @
9c5f6324
...
@@ -9,7 +9,7 @@ namespace gpu {
...
@@ -9,7 +9,7 @@ namespace gpu {
shape
hip_argmax
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
hip_argmax
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
}
}
...
...
src/targets/gpu/argmin.cpp
View file @
9c5f6324
...
@@ -9,7 +9,7 @@ namespace gpu {
...
@@ -9,7 +9,7 @@ namespace gpu {
shape
hip_argmin
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
hip_argmin
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
}
}
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
9c5f6324
...
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
...
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
size_t
batch_item_num
=
batch_lens
[
axis
];
size_t
batch_item_num
=
batch_lens
[
axis
];
batch_lens
[
axis
]
=
1
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
arg_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
arg_shape
.
type
(),
batch_lens
};
migraphx
::
shape
std_arg_shape
{
arg_shape
.
type
(),
arg_shape
.
lens
()};
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
hip_visit_all
(
arg
,
std_
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
*
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
auto
*
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
// use one block for items in one batch.
// use one block for items in one batch.
...
...
test/onnx/gen_onnx.py
View file @
9c5f6324
...
@@ -1694,6 +1694,16 @@ def hardsigmoid_verify_test():
...
@@ -1694,6 +1694,16 @@ def hardsigmoid_verify_test():
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
hardswish_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
2
,
5
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
,
5
])
node
=
onnx
.
helper
.
make_node
(
'HardSwish'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
@
onnx_test
def
if_else_test
():
def
if_else_test
():
x
=
onnx
.
helper
.
make_tensor_value_info
(
'x'
,
onnx
.
TensorProto
.
FLOAT
,
[
2
,
3
])
x
=
onnx
.
helper
.
make_tensor_value_info
(
'x'
,
onnx
.
TensorProto
.
FLOAT
,
[
2
,
3
])
...
...
test/onnx/hardswish_test.onnx
0 → 100644
View file @
9c5f6324
hardswish_test:M
xy" HardSwishhardswish_testZ
x
b
y
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
9c5f6324
...
@@ -1687,6 +1687,41 @@ TEST_CASE(hardsigmoid_half_test)
...
@@ -1687,6 +1687,41 @@ TEST_CASE(hardsigmoid_half_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
hardswish_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
std
::
size_t
>
input_lens
{
2
,
5
};
auto
input_type
=
migraphx
::
shape
::
float_type
;
migraphx
::
shape
s
{
input_type
,
input_lens
};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
float
alpha
=
1.0
/
6.0
;
float
beta
=
0.5
;
auto
mb_alpha
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
alpha
}}));
auto
mb_beta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
beta
}}));
auto
mb_zero
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
0
}}));
auto
mb_one
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
input_type
},
{
1
}}));
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb_alpha
,
x
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
mb_beta
,
mul
);
auto
hardsigmoid
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
mb_zero
,
mb_one
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x
,
hardsigmoid
);
auto
prog
=
optimize_onnx
(
"hardswish_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
if_else_test
)
TEST_CASE
(
if_else_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/py/onnx_backend_test.py
View file @
9c5f6324
...
@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None):
...
@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
include
(
r
'.*test_globalmaxpool.*'
)
backend_test
.
include
(
r
'.*test_globalmaxpool.*'
)
backend_test
.
include
(
r
'.*test_greater.*'
)
backend_test
.
include
(
r
'.*test_greater.*'
)
backend_test
.
include
(
r
'.*test_hardsigmoid.*'
)
backend_test
.
include
(
r
'.*test_hardsigmoid.*'
)
backend_test
.
include
(
r
'.*test_hardswish.*'
)
backend_test
.
include
(
r
'.*test_identity.*'
)
backend_test
.
include
(
r
'.*test_identity.*'
)
backend_test
.
include
(
r
'.*test_if.*'
)
backend_test
.
include
(
r
'.*test_if.*'
)
backend_test
.
include
(
r
'.*test_LeakyReLU*'
)
backend_test
.
include
(
r
'.*test_LeakyReLU*'
)
...
...
test/ref_ops_nonstd_shape_test.cpp
0 → 100644
View file @
9c5f6324
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE
(
argmax_test_nonstd_shape
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl_trans
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
dl
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
-
3
}}),
dl_trans
);
auto
p_uncompiled
=
p
;
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
res_gold
=
p_uncompiled
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmin_test_nonstd_shape
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl_trans
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
dl
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
-
1
}}),
dl_trans
);
auto
p_uncompiled
=
p
;
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
res_gold
=
p_uncompiled
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/verify/test_arg_ops.cpp
100755 → 100644
View file @
9c5f6324
...
@@ -2,34 +2,92 @@
...
@@ -2,34 +2,92 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
>>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1025
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
4
,
1025
}};
auto
param
=
mm
->
add_parameter
(
"data"
,
s
);
auto
param
=
mm
->
add_parameter
(
"data"
,
s
);
switch
(
NonStdShape
)
{
case
0
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
param
);
break
;
case
1
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
4
,
1025
}}}),
param
);
break
;
case
2
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
3
}}}),
param
);
break
;
default:
break
;
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
},
param
);
return
p
;
return
p
;
}
}
};
};
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment