Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6ae4227a
Unverified
Commit
6ae4227a
authored
Oct 20, 2023
by
Zakor Gyula
Committed by
GitHub
Oct 20, 2023
Browse files
Add support for select_last_index attribute for ArgMax & ArgMin (#2235)
parent
f47e0b5b
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
326 additions
and
87 deletions
+326
-87
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+9
-3
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+8
-2
src/onnx/parse_arg_op.cpp
src/onnx/parse_arg_op.cpp
+14
-3
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+3
-2
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+3
-2
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+10
-3
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+10
-3
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+39
-3
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
+3
-2
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
+3
-2
test/onnx/argmax_select_last_index_test.onnx
test/onnx/argmax_select_last_index_test.onnx
+0
-0
test/onnx/argmin_select_last_index_test.onnx
test/onnx/argmin_select_last_index_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+30
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+26
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-10
test/ref/argmax.cpp
test/ref/argmax.cpp
+34
-0
test/ref/argmin.cpp
test/ref/argmin.cpp
+34
-0
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+100
-52
No files found.
src/include/migraphx/op/argmax.hpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/float_equal.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -38,12 +39,13 @@ namespace op {
...
@@ -38,12 +39,13 @@ namespace op {
struct
argmax
struct
argmax
{
{
int64_t
axis
=
0
;
int64_t
axis
=
0
;
bool
select_last_index
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
select_last_index
,
"select_last_index"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -87,6 +89,10 @@ struct argmax
...
@@ -87,6 +89,10 @@ struct argmax
max_val
=
cur_val
;
max_val
=
cur_val
;
max_index
=
i
;
max_index
=
i
;
}
}
else
if
(
select_last_index
and
float_equal
(
max_val
,
cur_val
))
{
max_index
=
i
;
}
}
}
return
max_index
;
return
max_index
;
}
}
...
...
src/include/migraphx/op/argmin.hpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/float_equal.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -38,11 +39,12 @@ namespace op {
...
@@ -38,11 +39,12 @@ namespace op {
struct
argmin
struct
argmin
{
{
int64_t
axis
=
0
;
int64_t
axis
=
0
;
bool
select_last_index
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
select_last_index
,
"select_last_index"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -78,6 +80,10 @@ struct argmin
...
@@ -78,6 +80,10 @@ struct argmin
min_val
=
cur_val
;
min_val
=
cur_val
;
min_index
=
i
;
min_index
=
i
;
}
}
else
if
(
select_last_index
and
float_equal
(
min_val
,
cur_val
))
{
min_index
=
i
;
}
}
}
return
min_index
;
return
min_index
;
...
...
src/onnx/parse_arg_op.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
...
@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
keep_dims
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
keep_dims
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
}
bool
select_last_index
=
false
;
if
(
contains
(
info
.
attributes
,
"select_last_index"
))
{
select_last_index
=
static_cast
<
bool
>
(
parser
.
parse_value
(
info
.
attributes
.
at
(
"select_last_index"
)).
at
<
int
>
());
}
if
(
keep_dims
==
0
)
if
(
keep_dims
==
0
)
{
{
auto
ins
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
}}),
args
);
auto
ins
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
},
{
"select_last_index"
,
select_last_index
}}),
args
);
return
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
axis
}}}),
ins
);
return
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
axis
}}}),
ins
);
}
}
else
else
{
{
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
}}),
args
);
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
},
{
"select_last_index"
,
select_last_index
}}),
args
);
}
}
}
}
};
};
...
...
src/targets/gpu/argmax.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
...
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
return
args
.
back
();
}
}
...
...
src/targets/gpu/argmin.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
...
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
return
args
.
back
();
}
}
...
...
src/targets/gpu/device/argmax.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmax_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmax_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmin_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmin_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
...
@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return
{
v
,
i
};
return
{
v
,
i
};
}
}
struct
argmax_op
struct
argmax_op
_first_index
{
{
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
@@ -73,7 +73,25 @@ struct argmax_op
...
@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
};
struct
argmin_op
struct
argmax_op_last_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
>
y
.
val
)
return
x
;
else
if
(
x
.
val
<
y
.
val
)
return
y
;
else
{
return
(
x
.
index
>
y
.
index
)
?
x
:
y
;
}
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
struct
argmin_op_first_index
{
{
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
@@ -91,6 +109,24 @@ struct argmin_op
...
@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
};
struct
argmin_op_last_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
<
y
.
val
)
return
x
;
else
if
(
x
.
val
>
y
.
val
)
return
y
;
else
{
return
(
x
.
index
>
y
.
index
)
?
x
:
y
;
}
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
template
<
class
Op
>
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
{
{
...
...
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -36,7 +36,8 @@ namespace device {
...
@@ -36,7 +36,8 @@ namespace device {
void
MIGRAPHX_DEVICE_EXPORT
argmax
(
hipStream_t
stream
,
void
MIGRAPHX_DEVICE_EXPORT
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg
,
const
argument
&
arg
,
int64_t
axis
);
int64_t
axis
,
bool
select_last_index
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -36,7 +36,8 @@ namespace device {
...
@@ -36,7 +36,8 @@ namespace device {
void
MIGRAPHX_DEVICE_EXPORT
argmin
(
hipStream_t
stream
,
void
MIGRAPHX_DEVICE_EXPORT
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg
,
const
argument
&
arg
,
int64_t
axis
);
int64_t
axis
,
bool
select_last_index
);
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
test/onnx/argmax_select_last_index_test.onnx
0 → 100644
View file @
6ae4227a
File added
test/onnx/argmin_select_last_index_test.onnx
0 → 100644
View file @
6ae4227a
File added
test/onnx/gen_onnx.py
View file @
6ae4227a
...
@@ -149,6 +149,21 @@ def argmax_test():
...
@@ -149,6 +149,21 @@ def argmax_test():
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
argmax_select_last_index_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
6
])
node
=
onnx
.
helper
.
make_node
(
'ArgMax'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
axis
=
2
,
keepdims
=
0
,
select_last_index
=
1
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
@
onnx_test
()
def
argmax_dyn_test
():
def
argmax_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
None
,
4
,
5
,
6
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
None
,
4
,
5
,
6
])
...
@@ -177,6 +192,21 @@ def argmin_test():
...
@@ -177,6 +192,21 @@ def argmin_test():
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
argmin_select_last_index_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
])
node
=
onnx
.
helper
.
make_node
(
'ArgMin'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
axis
=
3
,
keepdims
=
0
,
select_last_index
=
1
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
@
onnx_test
()
def
asin_test
():
def
asin_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
])
...
...
test/onnx/onnx_test.cpp
View file @
6ae4227a
...
@@ -184,6 +184,19 @@ TEST_CASE(argmax_test)
...
@@ -184,6 +184,19 @@ TEST_CASE(argmax_test)
EXPECT(p == prog);
EXPECT(p == prog);
}
}
TEST_CASE(argmax_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmax", {{"axis", 2}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto prog = optimize_onnx("argmax_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmax_dyn_test)
TEST_CASE(argmax_dyn_test)
{
{
migraphx::program p;
migraphx::program p;
...
@@ -213,6 +226,19 @@ TEST_CASE(argmin_test)
...
@@ -213,6 +226,19 @@ TEST_CASE(argmin_test)
EXPECT(p == prog);
EXPECT(p == prog);
}
}
TEST_CASE(argmin_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmin", {{"axis", 3}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), ins);
auto prog = optimize_onnx("argmin_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(asin_test)
TEST_CASE(asin_test)
{
{
migraphx::program p;
migraphx::program p;
...
...
test/py/onnx_backend_test.py
View file @
6ae4227a
...
@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest):
...
@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest):
def
disabled_tests_onnx_1_7_0
(
backend_test
):
def
disabled_tests_onnx_1_7_0
(
backend_test
):
# fails
# fails
# from OnnxBackendNodeModelTest
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_argmax_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_0_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_0_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_1_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_1_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_default_axis_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_default_axis_cpu'
)
...
...
test/ref/argmax.cpp
View file @
6ae4227a
...
@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape)
...
@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
argmax_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmax_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
test/ref/argmin.cpp
View file @
6ae4227a
...
@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape)
...
@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
argmin_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
test/verify/test_arg_ops.cpp
View file @
6ae4227a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -29,8 +29,8 @@
...
@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
template
<
class
T
,
int
Axis
,
bool
LastIndex
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
LastIndex
,
NonStdShape
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break
;
break
;
default:
break
;
default:
break
;
}
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
,
LastIndex
},
param
);
return
p
;
return
p
;
}
}
};
};
// transpose argmax tests
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
0
>;
// transpose argmin tests
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
0
>;
// broadcast argmax tests
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
1
>;
// broadcast argmin tests
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
1
>;
// slice argmax tests
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
2
>;
// slice argmin tests
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
2
>;
// default case, standard shape argmax tests
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
3
>;
// default case, standard shape argmin tests
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
3
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment