Commit 8e485cc8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

some additional code cleanup

parent a6477298
...@@ -358,12 +358,16 @@ void triadd_layernorm(hipStream_t stream, ...@@ -358,12 +358,16 @@ void triadd_layernorm(hipStream_t stream,
auto batch_item_num = in_s.lens().back(); auto batch_item_num = in_s.lens().back();
if(type == shape::half_type and (batch_item_num % 2) == 0) if(type == shape::half_type and (batch_item_num % 2) == 0)
{ {
auto half2_block_size = compute_block_size(batch_item_num, 1024); auto block_size = compute_block_size(batch_item_num, 1024);
int block_num = in_s.elements() / batch_item_num; int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size(); int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4; auto half2_block_size = block_size / 4;
triadd_layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>( triadd_layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size); arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
// auto half_block_size = block_size / 2;
// triadd_layernorm_half2<<<block_num, half_block_size, shared_size, stream>>>(
// arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half_block_size);
} }
else else
{ {
...@@ -423,12 +427,16 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1) ...@@ -423,12 +427,16 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
auto batch_item_num = in_s.lens().back(); auto batch_item_num = in_s.lens().back();
if(type == shape::half_type and (batch_item_num % 2) == 0) if(type == shape::half_type and (batch_item_num % 2) == 0)
{ {
auto half2_block_size = compute_block_size(batch_item_num, 1024); auto block_size = compute_block_size(batch_item_num, 1024);
int block_num = in_s.elements() / batch_item_num; int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size(); int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4; auto half2_block_size = block_size / 4;
layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>( layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), result.data(), batch_item_num, half2_block_size); arg1.data(), result.data(), batch_item_num, half2_block_size);
// auto half_block_size = block_size / 2;
// layernorm_half2<<<block_num, half_block_size, shared_size, stream>>>(
// arg1.data(), result.data(), batch_item_num, half_block_size);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment