Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2013da67
Commit
2013da67
authored
Oct 20, 2023
by
Umang Yadav
Browse files
Merge branch 'develop' into resnet50_partition
parents
d3496ac6
6ae4227a
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
326 additions
and
87 deletions
+326
-87
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+9
-3
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+8
-2
src/onnx/parse_arg_op.cpp
src/onnx/parse_arg_op.cpp
+14
-3
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+3
-2
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+3
-2
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+10
-3
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+10
-3
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+39
-3
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
+3
-2
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
+3
-2
test/onnx/argmax_select_last_index_test.onnx
test/onnx/argmax_select_last_index_test.onnx
+0
-0
test/onnx/argmin_select_last_index_test.onnx
test/onnx/argmin_select_last_index_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+30
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+26
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-10
test/ref/argmax.cpp
test/ref/argmax.cpp
+34
-0
test/ref/argmin.cpp
test/ref/argmin.cpp
+34
-0
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+100
-52
No files found.
src/include/migraphx/op/argmax.hpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -31,6 +31,7 @@
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/float_equal.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -38,12 +39,13 @@ namespace op {
struct
argmax
{
int64_t
axis
=
0
;
int64_t
axis
=
0
;
bool
select_last_index
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
select_last_index
,
"select_last_index"
)
);
}
value
attributes
()
const
...
...
@@ -87,6 +89,10 @@ struct argmax
max_val
=
cur_val
;
max_index
=
i
;
}
else
if
(
select_last_index
and
float_equal
(
max_val
,
cur_val
))
{
max_index
=
i
;
}
}
return
max_index
;
}
...
...
src/include/migraphx/op/argmin.hpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/float_equal.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -38,11 +39,12 @@ namespace op {
struct
argmin
{
int64_t
axis
=
0
;
bool
select_last_index
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
select_last_index
,
"select_last_index"
)
);
}
value
attributes
()
const
...
...
@@ -78,6 +80,10 @@ struct argmin
min_val
=
cur_val
;
min_index
=
i
;
}
else
if
(
select_last_index
and
float_equal
(
min_val
,
cur_val
))
{
min_index
=
i
;
}
}
return
min_index
;
...
...
src/onnx/parse_arg_op.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
keep_dims
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
bool
select_last_index
=
false
;
if
(
contains
(
info
.
attributes
,
"select_last_index"
))
{
select_last_index
=
static_cast
<
bool
>
(
parser
.
parse_value
(
info
.
attributes
.
at
(
"select_last_index"
)).
at
<
int
>
());
}
if
(
keep_dims
==
0
)
{
auto
ins
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
}}),
args
);
auto
ins
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
},
{
"select_last_index"
,
select_last_index
}}),
args
);
return
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
axis
}}}),
ins
);
}
else
{
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
}}),
args
);
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"axis"
,
axis
},
{
"select_last_index"
,
select_last_index
}}),
args
);
}
}
};
...
...
src/targets/gpu/argmax.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/argmin.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{
auto
n_dim
=
args
.
front
().
get_shape
().
lens
().
size
();
int64_t
tuned_axis
=
tune_axis
(
n_dim
,
op
.
axis
,
op
.
name
());
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
);
device
::
argmin
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
tuned_axis
,
op
.
select_last_index
);
return
args
.
back
();
}
...
...
src/targets/gpu/device/argmax.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmax_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmax_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
,
bool
select_last_index
)
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
if
(
select_last_index
)
arg_op
(
argmin_op_last_index
{},
stream
,
result
,
arg
,
axis
);
else
arg_op
(
argmin_op_first_index
{},
stream
,
result
,
arg
,
axis
);
}
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return
{
v
,
i
};
}
struct
argmax_op
struct
argmax_op
_first_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
...
@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
struct
argmin_op
struct
argmax_op_last_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
>
y
.
val
)
return
x
;
else
if
(
x
.
val
<
y
.
val
)
return
y
;
else
{
return
(
x
.
index
>
y
.
index
)
?
x
:
y
;
}
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
struct
argmin_op_first_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
...
...
@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
struct
argmin_op_last_index
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
<
y
.
val
)
return
x
;
else
if
(
x
.
val
>
y
.
val
)
return
y
;
else
{
return
(
x
.
index
>
y
.
index
)
?
x
:
y
;
}
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
)
{
...
...
src/targets/gpu/include/migraphx/gpu/device/argmax.hpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -36,7 +36,8 @@ namespace device {
void
MIGRAPHX_DEVICE_EXPORT
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
);
int64_t
axis
,
bool
select_last_index
);
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/argmin.hpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -36,7 +36,8 @@ namespace device {
void
MIGRAPHX_DEVICE_EXPORT
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int64_t
axis
);
int64_t
axis
,
bool
select_last_index
);
}
// namespace device
}
// namespace gpu
...
...
test/onnx/argmax_select_last_index_test.onnx
0 → 100644
View file @
2013da67
File added
test/onnx/argmin_select_last_index_test.onnx
0 → 100644
View file @
2013da67
File added
test/onnx/gen_onnx.py
View file @
2013da67
...
...
@@ -149,6 +149,21 @@ def argmax_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
argmax_select_last_index_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
6
])
node
=
onnx
.
helper
.
make_node
(
'ArgMax'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
axis
=
2
,
keepdims
=
0
,
select_last_index
=
1
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
argmax_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
None
,
4
,
5
,
6
])
...
...
@@ -177,6 +192,21 @@ def argmin_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
argmin_select_last_index_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
])
node
=
onnx
.
helper
.
make_node
(
'ArgMin'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
axis
=
3
,
keepdims
=
0
,
select_last_index
=
1
)
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
asin_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
10
])
...
...
test/onnx/onnx_test.cpp
View file @
2013da67
...
...
@@ -184,6 +184,19 @@ TEST_CASE(argmax_test)
EXPECT(p == prog);
}
TEST_CASE(argmax_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmax", {{"axis", 2}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto prog = optimize_onnx("argmax_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
...
...
@@ -213,6 +226,19 @@ TEST_CASE(argmin_test)
EXPECT(p == prog);
}
TEST_CASE(argmin_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmin", {{"axis", 3}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), ins);
auto prog = optimize_onnx("argmin_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(asin_test)
{
migraphx::program p;
...
...
test/py/onnx_backend_test.py
View file @
2013da67
...
...
@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest):
def
disabled_tests_onnx_1_7_0
(
backend_test
):
# fails
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_argmax_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_0_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_1_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_default_axis_cpu'
)
...
...
test/ref/argmax.cpp
View file @
2013da67
...
...
@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmax_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmax_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
test/ref/argmin.cpp
View file @
2013da67
...
...
@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmin_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
test/verify/test_arg_ops.cpp
View file @
2013da67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
template
<
class
T
,
int
Axis
,
bool
LastIndex
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
LastIndex
,
NonStdShape
>>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break
;
default:
break
;
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
,
LastIndex
},
param
);
return
p
;
}
};
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
0
>;
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
0
>;
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
1
>;
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
1
>;
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
2
>;
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
2
>;
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
3
>;
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
3
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment