Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1da02b0f
"docs/_removed/examples.rst" did not exist on "abd164c2598d4cf19a081b4e5c1070de7bea8386"
Commit
1da02b0f
authored
Mar 09, 2022
by
Shucai Xiao
Browse files
backup softmax changes
parent
ae59a3b1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
14 deletions
+11
-14
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+11
-14
No files found.
src/targets/gpu/device/softmax.cpp
View file @
1da02b0f
...
...
@@ -39,12 +39,11 @@ __device__ __half2
block_reduce
(
__half2
*
buffer
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
{
__syncthreads
();
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*
=
2
)
for
(
index_int
s
=
block_size
;
s
>
0
;
s
>>
=
1
)
{
const
index_int
index
=
2
*
s
*
tid
;
if
(
index
+
s
<
batch_item_num
)
if
(
tid
<
s
and
tid
+
s
<
batch_item_num
)
{
buffer
[
index
]
=
op
(
buffer
[
index
],
buffer
[
index
+
s
]);
buffer
[
tid
]
=
op
(
buffer
[
tid
],
buffer
[
tid
+
s
]);
}
__syncthreads
();
}
...
...
@@ -61,12 +60,11 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
__half2
*
input
=
reinterpret_cast
<
__half2
*>
(
data_in
);
__half2
*
output
=
reinterpret_cast
<
__half2
*>
(
data_out
);
batch_item_num
/=
2
;
int
tid
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
extern
MIGRAPHX_DEVICE_SHARED
__half2
buffer2
[];
__half2
*
in_data_reduce
=
buffer2
;
__half2
*
in_data
=
buffer2
+
batch_item_num
;
int
start
=
tid
/
block_size
*
batch_item_num
;
int
start
=
blockIdx
.
x
*
batch_item_num
;
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
auto
d
=
input
[
i
+
start
];
...
...
@@ -98,12 +96,11 @@ __device__ __half
block_reduce2
(
__half
*
data
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
{
__syncthreads
();
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*
=
2
)
for
(
index_int
s
=
block_size
/
2
;
s
>
0
;
s
>>
=
1
)
{
const
index_int
index
=
2
*
s
*
tid
;
if
(
index
+
s
<
batch_item_num
)
if
(
tid
<
s
and
tid
+
s
<
batch_item_num
)
{
data
[
index
]
=
op
(
data
[
index
],
data
[
index
+
s
]);
data
[
tid
]
=
op
(
data
[
tid
],
data
[
tid
+
s
]);
}
__syncthreads
();
}
...
...
@@ -158,13 +155,13 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if
(
axis
==
batch_lens
.
size
()
-
1
)
{
auto
in_type
=
result
.
get_shape
().
type
();
if
(
in_type
==
shape
::
half_type
and
batch_item_num
<=
2048
)
if
(
in_type
==
shape
::
half_type
and
batch_item_num
<=
1024
)
{
int
block_num
=
batch_shape
.
elements
();
int
shared_size
=
batch_item_num
*
2
*
result
.
get_shape
().
type_size
();
softmax_kernel
<<<
block_num
,
block_size
,
shared_size
,
stream
>>>
(
arg
.
data
(),
batch_item_num
,
block_size
,
result
.
data
());
auto
half2_block_size
=
block_size
/
4
;
softmax_kernel
<<<
block_num
,
half2_
block_size
,
shared_size
,
stream
>>>
(
arg
.
data
(),
batch_item_num
,
half2_
block_size
,
result
.
data
());
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment