Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9f06859b
"Dockerfile" did not exist on "9c4a1c74eafd996d9c6bcfe1379bea71f48fbf58"
Commit
9f06859b
authored
Mar 08, 2022
by
Shucai Xiao
Browse files
final version of softmax that works
parent
bc9eac75
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
6 deletions
+4
-6
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+4
-6
No files found.
src/targets/gpu/device/softmax.cpp
View file @
9f06859b
...
@@ -38,6 +38,7 @@ template <class Op>
...
@@ -38,6 +38,7 @@ template <class Op>
__device__
__half2
__device__
__half2
block_reduce
(
__half2
*
buffer
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
block_reduce
(
__half2
*
buffer
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
{
{
__syncthreads
();
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*=
2
)
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*=
2
)
{
{
const
index_int
index
=
2
*
s
*
tid
;
const
index_int
index
=
2
*
s
*
tid
;
...
@@ -96,6 +97,7 @@ template <class Op>
...
@@ -96,6 +97,7 @@ template <class Op>
__device__
__half
__device__
__half
block_reduce2
(
__half
*
data
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
block_reduce2
(
__half
*
data
,
index_int
batch_item_num
,
index_int
tid
,
index_int
block_size
,
Op
op
)
{
{
__syncthreads
();
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*=
2
)
for
(
index_int
s
=
1
;
s
<
block_size
;
s
*=
2
)
{
{
const
index_int
index
=
2
*
s
*
tid
;
const
index_int
index
=
2
*
s
*
tid
;
...
@@ -124,21 +126,16 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
...
@@ -124,21 +126,16 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
auto
d
=
input
[
i
+
start
];
auto
d
=
input
[
i
+
start
];
in_data
[
i
]
=
d
;
in_data
[
i
]
=
d
;
in_data_reduce
[
i
]
=
d
;
in_data_reduce
[
i
]
=
d
;
// printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
}
}
auto
batch_max
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
max
{});
auto
batch_max
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
max
{});
// printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
{
in_data
[
i
]
=
__float2half
(
::
exp
(
__half2float
(
in_data
[
i
])
-
__half2float
(
batch_max
)));
in_data
[
i
]
=
__float2half
(
::
exp
(
__half2float
(
in_data
[
i
])
-
__half2float
(
batch_max
)));
in_data_reduce
[
i
]
=
in_data
[
i
];
in_data_reduce
[
i
]
=
in_data
[
i
];
// printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
}
}
auto
batch_sum
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
sum
{});
auto
batch_sum
=
block_reduce2
(
in_data_reduce
,
batch_item_num
,
threadIdx
.
x
,
block_size
,
sum
{});
// printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_item_num
;
i
+=
block_size
)
{
{
output
[
i
+
start
]
=
__float2half
(
__half2float
(
in_data
[
i
])
/
__half2float
(
batch_sum
));
output
[
i
+
start
]
=
__float2half
(
__half2float
(
in_data
[
i
])
/
__half2float
(
batch_sum
));
...
@@ -153,7 +150,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -153,7 +150,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
const
index_int
max_block_size
=
1
28
;
const
index_int
max_block_size
=
1
024
;
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
type
init
=
lowest
();
type
init
=
lowest
();
...
@@ -165,6 +162,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -165,6 +162,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
{
{
int
block_num
=
batch_shape
.
elements
();
int
block_num
=
batch_shape
.
elements
();
int
shared_size
=
batch_item_num
*
2
*
result
.
get_shape
().
type_size
();
int
shared_size
=
batch_item_num
*
2
*
result
.
get_shape
().
type_size
();
softmax_kernel
<<<
block_num
,
block_size
,
shared_size
,
stream
>>>
(
softmax_kernel
<<<
block_num
,
block_size
,
shared_size
,
stream
>>>
(
arg
.
data
(),
batch_item_num
,
block_size
,
result
.
data
());
arg
.
data
(),
batch_item_num
,
block_size
,
result
.
data
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment