Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4924bb45
Commit
4924bb45
authored
Jul 02, 2019
by
Shucai Xiao
Browse files
Merge branch 'argmax_min' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into argmax_min
parents
add40fd1
8ffe3180
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
95 deletions
+54
-95
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+3
-2
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+3
-2
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+1
-1
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+46
-89
No files found.
src/include/migraphx/op/argmax.hpp
View file @
4924bb45
...
...
@@ -45,9 +45,10 @@ struct argmax
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
if
(
max_val
<
input
(
indices
.
begin
(),
indices
.
end
()))
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
max_val
<
cur_val
)
{
max_val
=
input
(
indices
.
begin
(),
indices
.
end
())
;
max_val
=
cur_val
;
max_index
=
i
;
}
}
...
...
src/include/migraphx/op/argmin.hpp
View file @
4924bb45
...
...
@@ -50,9 +50,10 @@ struct argmin
for
(
std
::
size_t
i
=
1
;
i
<
item_num
;
++
i
)
{
indices
[
axis
]
=
i
;
if
(
min_val
>
input
(
indices
.
begin
(),
indices
.
end
()))
auto
cur_val
=
input
(
indices
.
begin
(),
indices
.
end
());
if
(
min_val
>
cur_val
)
{
min_val
=
input
(
indices
.
begin
(),
indices
.
end
())
;
min_val
=
cur_val
;
min_index
=
i
;
}
}
...
...
src/targets/gpu/device/argmax.cpp
View file @
4924bb45
...
...
@@ -16,7 +16,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
{
arg
.
visit
([
&
](
auto
input
)
{
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
arg_op
<
pair_max
<
type
,
int64_t
>>
(
pair_max
<
type
,
int64_t
>
{},
stream
,
result
,
arg
,
axis
);
arg_op
<
type
,
argmax_op
<
type
>>
(
argmax_op
<
type
>
{},
stream
,
result
,
arg
,
axis
);
});
}
...
...
src/targets/gpu/device/argmin.cpp
View file @
4924bb45
...
...
@@ -16,7 +16,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
{
arg
.
visit
([
&
](
auto
input
)
{
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
arg_op
<
pair_min
<
type
,
int64_t
>>
(
pair_min
<
type
,
int64_t
>
{},
stream
,
result
,
arg
,
axis
);
arg_op
<
type
,
argmin_op
<
type
>>
(
argmin_op
<
type
>
{},
stream
,
result
,
arg
,
axis
);
});
}
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
4924bb45
...
...
@@ -6,6 +6,7 @@
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/hip.hpp>
namespace
migraphx
{
...
...
@@ -13,71 +14,50 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
template
<
class
T
,
class
F
>
struct
pair_ma
x
template
<
class
T
>
struct
val_inde
x
{
using
type
=
std
::
pair
<
T
,
F
>
;
// This implementation is to ensure when multiple values
// are of max, the min index is returned
type
operator
()(
type
x
,
type
y
)
const
T
val
;
int64_t
index
;
};
template
<
class
T
>
struct
argmax_op
{
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
first
>
y
.
first
)
if
(
x
.
val
>
y
.
val
)
return
x
;
else
if
(
x
.
first
<
y
.
first
)
else
if
(
x
.
val
<
y
.
val
)
return
y
;
else
{
return
(
x
.
second
<
y
.
second
)
?
x
:
y
;
return
(
x
.
index
<
y
.
index
)
?
x
:
y
;
}
}
};
template
<
class
T
,
class
F
>
struct
pair_min
{
using
type
=
std
::
pair
<
T
,
F
>
;
type
operator
()(
type
x
,
type
y
)
const
{
return
(
x
<
y
)
?
x
:
y
;
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
lowest
();
}
};
template
<
class
T
,
class
Op
>
inline
__device__
void
block_reduce_arg
(
T
*
data_ptr
,
int64_t
*
index_ptr
,
Op
op
,
std
::
size_t
block_size
,
std
::
size_t
thr_idx
,
std
::
size_t
item_num
,
std
::
size_t
output_index
)
template
<
class
T
>
struct
argmin_op
{
while
(
true
)
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
size
=
item_num
/
2
;
for
(
std
::
size_t
i
=
thr_idx
;
i
<
size
;
i
+=
block_size
)
if
(
x
.
val
<
y
.
val
)
return
x
;
else
if
(
x
.
val
>
y
.
val
)
return
y
;
else
{
auto
output
=
op
({
data_ptr
[
i
],
index_ptr
[
i
]},
{
data_ptr
[
i
+
stride
],
index_ptr
[
i
+
stride
]});
data_ptr
[
i
]
=
output
.
first
;
index_ptr
[
i
]
=
output
.
second
;
return
(
x
.
index
<
y
.
index
)
?
x
:
y
;
}
__syncthreads
();
item_num
=
stride
;
if
(
item_num
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
auto
output
=
op
({
data_ptr
[
output_index
],
index_ptr
[
output_index
]},
{
data_ptr
[
0
],
index_ptr
[
0
]});
data_ptr
[
output_index
]
=
output
.
first
;
index_ptr
[
output_index
]
=
output
.
second
;
}
__syncthreads
();
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
highest
();
}
};
template
<
class
Op
>
template
<
class
T
,
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
arg_shape
=
arg
.
get_shape
();
...
...
@@ -90,50 +70,27 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
// use one block for items in one batch.
const
size_t
max_block_size
=
1024
;
size_t
block_size
=
1
;
while
(
block_size
<
max_block_size
and
block_size
<
batch_item_num
)
{
block_size
*=
2
;
}
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
auto
batch_idx
=
batch_s
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
max_block_size
+
1
];
MIGRAPHX_DEVICE_SHARED
int64_t
lds_index
[
max_block_size
+
1
];
// load data to lds_data
size_t
round_item_num
=
(
batch_item_num
+
block_size
-
1
)
/
block_size
*
block_size
;
size_t
remaining_item_num
=
batch_item_num
;
data_idx
[
axis
]
=
0
;
lds_data
[
max_block_size
]
=
input
[
arg_s
.
index
(
data_idx
)];
lds_index
[
max_block_size
]
=
0
;
for
(
size_t
i
=
thr_idx
;
i
<
round_item_num
;
i
+=
block_size
)
{
if
(
i
<
batch_item_num
)
const
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
data_idx
=
batch_idx
;
T
init_val
=
op
.
init
();
val_index
<
T
>
init
=
{
init_val
,
-
1
};
auto
op_output
=
block_reduce
<
max_block_size
,
Op
,
val_index
<
T
>>
(
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
T
val
=
input
[
arg_s
.
index
(
data_idx
)];
return
val_index
<
T
>
{
val
,
static_cast
<
int64_t
>
(
j
)};
});
if
(
idx
.
local
==
0
)
{
data_idx
[
axis
]
=
i
;
lds_index
[
thr_idx
]
=
i
;
lds_data
[
thr_idx
]
=
input
[
arg_s
.
index
(
data_idx
)];
output
[
batch_s
.
index
(
batch_idx
)]
=
op_output
.
index
;
}
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
block_reduce_arg
<
type
,
Op
>
(
lds_data
,
lds_index
,
op
,
block_size
,
thr_idx
,
item_num
,
max_block_size
);
remaining_item_num
-=
block_size
;
}
if
(
thr_idx
==
0
)
{
output
[
batch_s
.
index
(
batch_idx
)]
=
lds_index
[
max_block_size
];
}
});
});
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment