gaoqiong / MIGraphX

Commit 69c94135, authored Mar 28, 2022 by Shucai Xiao
Parent: 580673a0

    half and half2 have the same results
Showing 1 changed file with 23 additions and 24 deletions.
src/targets/gpu/device/layernorm.cpp (+23, -24)
@@ -237,13 +237,13 @@ __global__ void triadd_layernorm_kernel_half2(
     __half2* input2 = reinterpret_cast<__half2*>(in2);
     __half2* input3 = reinterpret_cast<__half2*>(in3);
     __half2* output = reinterpret_cast<__half2*>(data_out);
+    auto rnum = __float2half2_rn(1.0f / batch_item_num);
     batch_item_num /= 2;
     extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
     __half2* in_data_reduce = buffer2;
     __half2* in_data = buffer2 + batch_item_num;
     int start = blockIdx.x * batch_item_num;
-    auto rnum = __float2half2_rn(1.0f / batch_item_num);
     for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         int idx = i + start;
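In the hunk above, the `__half` inputs are reinterpreted as `__half2` pairs and `batch_item_num` is halved for the loop bounds, while `rnum` (the reciprocal used for the mean) is now taken from the original, unhalved element count. The sketch below is a minimal, self-contained illustration of that pattern, not MIGraphX code: the kernel name `half_vs_half2_mean`, the data values, and the host driver are all hypothetical, and the dynamic shared array stands in for the `MIGRAPHX_DEVICE_SHARED buffer2[]` declaration (assumed here to behave like plain `__shared__`). It computes the mean of one row once per element and once per `__half2` pair, using the same `1.0f / batch_item_num` reciprocal for both, which is why the two paths agree, which is the point of this commit.

    // Minimal sketch, not the MIGraphX kernel: names here are hypothetical.
    // One block fills a row of __half values in dynamic shared memory, then
    // thread 0 computes the row mean element-wise and again through __half2
    // pairs. Both use the reciprocal of the FULL element count, taken before
    // the count is halved for the pairwise loop.
    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>
    #include <cstdio>

    __global__ void half_vs_half2_mean(float* out, int batch_item_num)
    {
        extern __shared__ __half row[]; // dynamic shared memory, sized at launch
        for(int i = threadIdx.x; i < batch_item_num; i += blockDim.x)
            row[i] = __float2half(static_cast<float>(i % 7) * 0.25f);
        __syncthreads();

        if(threadIdx.x == 0)
        {
            float rnum = 1.0f / batch_item_num; // reciprocal of the original count

            // Element-wise (__half) mean.
            float sum = 0.0f;
            for(int i = 0; i < batch_item_num; ++i)
                sum += __half2float(row[i]);
            out[0] = sum * rnum;

            // Pairwise (__half2) mean: reinterpret the row and halve the count
            // only for the loop bound, as the kernel in the diff does.
            const __half2* row2 = reinterpret_cast<const __half2*>(row);
            int pair_num = batch_item_num / 2;
            float sum2 = 0.0f;
            for(int i = 0; i < pair_num; ++i)
            {
                float2 v = __half22float2(row2[i]);
                sum2 += v.x + v.y; // both lanes feed the same row sum
            }
            out[1] = sum2 * rnum; // matches out[0] because rnum was not rescaled
        }
    }

    int main()
    {
        const int batch_item_num = 768; // an even row length, e.g. a hidden size
        float* dout = nullptr;
        hipMalloc(&dout, 2 * sizeof(float));

        half_vs_half2_mean<<<1, 256, batch_item_num * sizeof(__half)>>>(dout, batch_item_num);

        float hout[2] = {0.0f, 0.0f};
        hipMemcpy(hout, dout, 2 * sizeof(float), hipMemcpyDeviceToHost);
        std::printf("half mean = %f, half2 mean = %f\n", hout[0], hout[1]);
        hipFree(dout);
        return 0;
    }

If the reciprocal were instead taken after `batch_item_num /= 2;`, the pairwise mean would come out twice as large as the element-wise one, which is consistent with the `rnum` line being moved above the divide in this change.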
@@ -403,31 +403,30 @@ void triadd_layernorm(hipStream_t stream,
     auto in_s = arg1.get_shape();
     auto type = in_s.type();
     auto batch_item_num = in_s.lens().back();
-    // if(type == shape::half_type and (batch_item_num % 2) == 0)
-    // {
-    //     auto half2_block_size = compute_block_size(batch_item_num, 1024);
-    //     int block_num = in_s.elements() / batch_item_num;
-    //     int shared_size = batch_item_num * 2 * in_s.type_size();
-    //     half2_block_size = half2_block_size / 4;
-    //     triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
-    //         arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num,
-    //         half2_block_size);
-    // }
-    // if(type == shape::half_type and (batch_item_num % 2) == 0)
-    if(type == shape::half_type)
+    if(type == shape::half_type and (batch_item_num % 2) == 0)
     {
-        auto reduce_block_size = compute_block_size(batch_item_num, 1024);
-        int block_num = in_s.elements() / batch_item_num;
-        int shared_size = batch_item_num * 2 * in_s.type_size();
-        reduce_block_size = reduce_block_size / 2;
-        triadd_layernorm_kernel<__half>
-            <<<block_num, reduce_block_size, shared_size, stream>>>(arg1.data(),
-                                                                    arg2.data(),
-                                                                    arg3.data(),
-                                                                    result.data(),
-                                                                    batch_item_num,
-                                                                    reduce_block_size);
+        auto half2_block_size = compute_block_size(batch_item_num, 1024);
+        int block_num = in_s.elements() / batch_item_num;
+        int shared_size = batch_item_num * 2 * in_s.type_size();
+        half2_block_size = half2_block_size / 4;
+        triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
+            arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num,
+            half2_block_size);
     }
+    // if(type == shape::half_type)
+    // {
+    //     auto reduce_block_size = compute_block_size(batch_item_num, 1024);
+    //     int block_num = in_s.elements() / batch_item_num;
+    //     int shared_size = batch_item_num * 2 * in_s.type_size();
+    //     reduce_block_size = reduce_block_size / 2;
+    //     triadd_layernorm_kernel<__half>
+    //         <<<block_num, reduce_block_size, shared_size, stream>>>(arg1.data(),
+    //                                                                 arg2.data(),
+    //                                                                 arg3.data(),
+    //                                                                 result.data(),
+    //                                                                 batch_item_num,
+    //                                                                 reduce_block_size);
+    // }
     else
     {
         layernorm_fusion(stream, result, arg1, arg2, arg3)(
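The call site above derives the launch configuration from the shape: one block per row (`block_num = in_s.elements() / batch_item_num`), shared memory for two row-sized buffers (`batch_item_num * 2 * in_s.type_size()`), and a block size from `compute_block_size` that the new half2 path divides by 4 (the removed `__half` path divided it by 2). A rough host-only sketch of that arithmetic follows; `compute_block_size_sketch` is a hypothetical stand-in for MIGraphX's `compute_block_size`, assumed here to pick a power of two capped at `max_block`, and the example shape is made up. The real code additionally requires `type == shape::half_type` and otherwise falls back to `layernorm_fusion`, which the sketch only notes in a comment.

    // Host-only sketch of the launch-parameter arithmetic at the call site.
    // compute_block_size_sketch is a hypothetical stand-in; the real helper
    // in MIGraphX may choose the block size differently.
    #include <cstdio>
    #include <cstddef>

    static std::size_t compute_block_size_sketch(std::size_t n, std::size_t max_block)
    {
        std::size_t block = 64;
        while(block < max_block and block < n)
            block *= 2;
        return block;
    }

    int main()
    {
        // Example tensor: 16 x 384 x 768 halfs, normalized over the last axis.
        std::size_t elements       = std::size_t{16} * 384 * 768;
        std::size_t batch_item_num = 768; // innermost (normalized) axis length
        std::size_t type_size      = 2;   // sizeof(__half)

        // Guard from the new if(): the half2 kernel needs an even row length;
        // odd rows (and non-half types) go to layernorm_fusion instead.
        bool use_half2 = (batch_item_num % 2) == 0;

        std::size_t block_num   = elements / batch_item_num;      // one block per row
        std::size_t shared_size = batch_item_num * 2 * type_size; // two row-sized buffers
        std::size_t block_size  = compute_block_size_sketch(batch_item_num, 1024);
        block_size              = use_half2 ? block_size / 4      // half2 kernel
                                            : block_size / 2;     // removed __half path

        std::printf("blocks=%zu threads=%zu shared=%zu bytes, half2=%d\n",
                    block_num, block_size, shared_size, static_cast<int>(use_half2));
        return 0;
    }

Dividing the block size by 4 rather than 2 reflects that each thread of the half2 kernel handles two packed elements per iteration, so fewer threads per row are needed for the same amount of work.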