Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f2726514
Commit
f2726514
authored
Jun 21, 2019
by
Shucai Xiao
Browse files
further optimization of the logsoftmax operator.
parent
7959306c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
16 deletions
+36
-16
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+36
-16
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
f2726514
...
@@ -38,7 +38,6 @@ argument logsoftmax(hipStream_t stream,
...
@@ -38,7 +38,6 @@ argument logsoftmax(hipStream_t stream,
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
size_t
blk_idx
=
idx
.
group
;
// using type = typename decltype(input)::value_type;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
// all data can be loaded to the lds once, so all operations are
// all data can be loaded to the lds once, so all operations are
...
@@ -56,18 +55,29 @@ argument logsoftmax(hipStream_t stream,
...
@@ -56,18 +55,29 @@ argument logsoftmax(hipStream_t stream,
__syncthreads
();
__syncthreads
();
// use thread 0 for batch_max
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
if
(
thr_idx
==
0
)
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
{
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
if
(
thr_idx
+
stride
<
size
)
for
(
size_t
j
=
0
;
j
<
size
;
j
++
)
{
{
lds_data
[
block_size
]
=
lds_data
[
thr_idx
]
=
::
max
(
to_hip_type
(
lds_data
[
block_size
]),
to_hip_type
(
lds_data
[
j
]));
::
max
(
to_hip_type
(
lds_data
[
thr_idx
]),
to_hip_type
(
lds_data
[
thr_idx
+
stride
]));
}
}
item_num
-=
block_size
;
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
if
(
size
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
lds_data
[
block_size
]
=
(
lds_data
[
0
]
<
lds_data
[
block_size
])
?
lds_data
[
block_size
]
:
lds_data
[
0
];
}
}
__syncthreads
();
__syncthreads
();
item_num
-=
block_size
;
}
}
const
size_t
block_size1
=
block_size
+
1
;
const
size_t
block_size1
=
block_size
+
1
;
...
@@ -76,22 +86,32 @@ argument logsoftmax(hipStream_t stream,
...
@@ -76,22 +86,32 @@ argument logsoftmax(hipStream_t stream,
for
(
size_t
i
=
thr_idx
;
i
<
num_in_batch
;
i
+=
block_size
)
for
(
size_t
i
=
thr_idx
;
i
<
num_in_batch
;
i
+=
block_size
)
{
{
data_idx
[
axis
]
=
i
;
data_idx
[
axis
]
=
i
;
lds_data
[
i
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)];
lds_data
[
i
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)]
-
lds_data
[
block_size
];
lds_data
[
i
]
=
::
exp
(
to_hip_type
(
lds_data
[
i
]));
__syncthreads
();
__syncthreads
();
// use thread 0 for batch_max
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
if
(
thr_idx
==
0
)
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
{
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
if
(
thr_idx
+
stride
<
size
)
for
(
size_t
j
=
0
;
j
<
size
;
j
++
)
{
{
lds_data
[
block_size1
]
+=
lds_data
[
thr_idx
]
+=
lds_data
[
thr_idx
+
stride
];
::
exp
(
to_hip_type
(
lds_data
[
j
]
-
lds_data
[
block_size
]));
}
}
item_num
-=
block_size
;
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
if
(
size
==
1
)
break
;
}
if
(
thr_idx
==
0
)
{
lds_data
[
block_size1
]
+=
lds_data
[
0
];
}
}
__syncthreads
();
__syncthreads
();
item_num
-=
block_size
;
}
}
auto
log_batch_sum
=
auto
log_batch_sum
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment