gaoqiong / MIGraphX · Commits · 63773ec0
"vscode:/vscode.git/clone" did not exist on "14344caa38e056ae569dce3b1aaefec712b16f2a"
Commit 63773ec0, authored Jun 24, 2019 by Shucai Xiao

code cleanup for softmax.

Parent: 6ae2f087
Changes: 1 changed file with 26 additions and 27 deletions (+26, -27).
src/targets/gpu/device/softmax.cpp
@@ -14,7 +14,7 @@ namespace device {
 template <class T>
 __device__ void
-reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
+reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
 {
     auto stride = (item_num + 1) / 2;
     while(true)
@@ -43,7 +43,7 @@ reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx
 template <class T>
 __device__ void
-reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
+reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
 {
     auto stride = (item_num + 1) / 2;
     while(true)
@@ -62,7 +62,7 @@ reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx
     if(thr_idx == 0)
     {
-        data_ptr[block_size + 1] += data_ptr[0];
+        data_ptr[block_size] += data_ptr[0];
     }
     __syncthreads();
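Note on the reduce_max/reduce_sum cleanup above: MIGRAPHX_DEVICE_SHARED (the LDS/shared-memory qualifier) is dropped from the pointer parameter, where it has no effect; the qualifier belongs on the buffer definition in the caller. Both helpers follow the same halving-stride in-block tree reduction. A minimal sketch of that pattern, assuming a buffer of block_size elements plus one result slot at data_ptr[block_size] and item_num >= 1; the generic `op` stands in for the max/plus of the real helpers, and this is not the verbatim MIGraphX source:

    template <class T, class Op>
    __device__ void block_reduce(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, Op op)
    {
        // Fold the upper half of the live range onto the lower half until
        // one element remains; (n + 1) / 2 rounds up for odd counts.
        auto stride = (item_num + 1) / 2;
        while(true)
        {
            if(thr_idx + stride < item_num)
                data_ptr[thr_idx] = op(data_ptr[thr_idx], data_ptr[thr_idx + stride]);
            __syncthreads();
            if(stride == 1)
                break;
            item_num = stride;
            stride   = (stride + 1) / 2;
        }
        // Thread 0 folds the block result into the single scratch slot at
        // data_ptr[block_size], which is what the index fix above is about.
        if(thr_idx == 0)
            data_ptr[block_size] = op(data_ptr[block_size], data_ptr[0]);
        __syncthreads();
    }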
@@ -72,7 +72,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
 {
     auto lens = result.get_shape().lens();
     auto batch_lens = lens;
-    size_t n_dims = lens[axis];
+    size_t batch_item_num = lens[axis];
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
@@ -86,7 +86,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     // use one block for items in one batch.
     const size_t max_block_size = 1024;
     size_t block_size = 1;
-    while(block_size < max_block_size and block_size < n_dims)
+    while(block_size < max_block_size and block_size < batch_item_num)
     {
         block_size *= 2;
     }
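The rename from n_dims to batch_item_num above also makes the intent of this loop clearer: block_size becomes the smallest power of two that covers the softmax axis, capped at 1024 threads per block. A standalone illustration (hypothetical helper, not part of the source):

    #include <cstddef>
    #include <cstdio>

    static std::size_t pick_block_size(std::size_t batch_item_num, std::size_t max_block_size = 1024)
    {
        std::size_t block_size = 1;
        while(block_size < max_block_size and block_size < batch_item_num)
            block_size *= 2;
        return block_size;
    }

    int main()
    {
        std::printf("%zu\n", pick_block_size(300));  // 512: one pass covers the axis
        std::printf("%zu\n", pick_block_size(4000)); // 1024: the axis needs several passes
    }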
@@ -97,19 +97,16 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     size_t blk_idx = idx.group;
     using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
     // all data can be loaded to the lds once, so all operations are
     // done in lds
-    MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
+    MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
     auto batch_idx = desc_batch.multi(blk_idx);
     auto data_idx = batch_idx;
     // load data to lds and compute the batch max
-    size_t item_num = n_dims;
-    size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;
+    size_t remaining_item_num = batch_item_num;
+    size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
     lds_data[block_size] = input_ptr[0];
-    lds_data[block_size + 1] = 0;
-    for(size_t i = thr_idx; i < thread_num; i += block_size)
+    for(size_t i = thr_idx; i < round_item_num; i += block_size)
     {
-        if(i < n_dims)
+        if(i < batch_item_num)
         {
             data_idx[axis] = i;
             lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
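The round_item_num rename (formerly thread_num) above says what the bound really is: the item count rounded up to a multiple of block_size, so every thread runs the same number of loop iterations and reaches the __syncthreads()/reduce calls together, with if(i < batch_item_num) masking the padding iterations. A compile-time check of the rounding (assumed helper, not in the source):

    #include <cstddef>

    // Round n up to the next multiple of block_size.
    constexpr std::size_t round_up(std::size_t n, std::size_t block_size)
    {
        return (n + block_size - 1) / block_size * block_size;
    }

    static_assert(round_up(300, 512) == 512, "one pass of 512 threads");
    static_assert(round_up(4000, 1024) == 4096, "four passes of 1024 threads");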
@@ -117,40 +114,42 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
         __syncthreads();
-        auto size = (item_num > block_size) ? block_size : item_num;
+        auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
         reduce_max<type>(lds_data, block_size, thr_idx, size);
         __syncthreads();
-        item_num -= block_size;
+        remaining_item_num -= block_size;
     }
-    item_num = n_dims;
-    for(size_t i = thr_idx; i < thread_num; i += block_size)
+    auto batch_max = lds_data[block_size];
+    __syncthreads();
+    lds_data[block_size] = 0;
+    remaining_item_num = batch_item_num;
+    for(size_t i = thr_idx; i < round_item_num; i += block_size)
     {
-        if(i < n_dims)
+        if(i < batch_item_num)
         {
             data_idx[axis] = i;
-            lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
+            lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - batch_max;
             lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
         }
         __syncthreads();
-        auto size = (item_num > block_size) ? block_size : item_num;
+        auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
         reduce_sum<type>(lds_data, block_size, thr_idx, size);
         __syncthreads();
-        item_num -= block_size;
+        remaining_item_num -= block_size;
     }
-    for(size_t i = thr_idx; i < n_dims; i += block_size)
+    auto batch_sum = lds_data[block_size];
+    for(size_t i = thr_idx; i < batch_item_num; i += block_size)
    {
         data_idx[axis] = i;
         size_t index = desc_data.linear(data_idx);
-        auto val = input_ptr[index] - lds_data[block_size];
-        output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
+        auto val = input_ptr[index] - batch_max;
+        output_ptr[index] = ::exp(to_hip_type(val)) / batch_sum;
     }
 });
 });
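Taken together, the kernel still computes the standard numerically stable softmax: subtract the slice max before exponentiating, then normalize by the sum of exponentials. The cleanup only moves the max and sum out of the two extra LDS slots into the batch_max/batch_sum locals, shrinking lds_data by one element. A host-side reference for one slice along the softmax axis (hypothetical helper, not MIGraphX code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> softmax_ref(const std::vector<float>& x)
    {
        // Subtracting the slice max keeps exp() from overflowing.
        float batch_max = *std::max_element(x.begin(), x.end());
        std::vector<float> out(x.size());
        float batch_sum = 0.0f;
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            out[i] = std::exp(x[i] - batch_max);
            batch_sum += out[i];
        }
        for(auto& v : out)
            v /= batch_sum;
        return out;
    }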