Commit 80a6ca93 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup code changes

parent a5181cd0
...@@ -160,7 +160,7 @@ void layernorm_impl(hipStream_t stream, ...@@ -160,7 +160,7 @@ void layernorm_impl(hipStream_t stream,
const Arguments&... args) const Arguments&... args)
{ {
hip_visit_all(result, args...)([&](auto output, auto... inputs) { hip_visit_all(result, args...)([&](auto output, auto... inputs) {
const std::size_t max_block_size = 128; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size); const std::size_t block_size_div = encode_divisor(block_size);
assert(relements <= block_size); assert(relements <= block_size);
...@@ -248,18 +248,23 @@ __global__ void triadd_layernorm_kernel_half2( ...@@ -248,18 +248,23 @@ __global__ void triadd_layernorm_kernel_half2(
{ {
int idx = i + start; int idx = i + start;
in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]); in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = __hmul2(in_data[i], rnum); in_data_reduce[i] = in_data[i];
// in_data_reduce[i] = __hmul2(in_data[i], rnum);
} }
auto m = auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __hsub2(in_data[i], m); in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum); // in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
} }
m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
auto eps = __float2half2_rn(1.0e-12f); auto eps = __float2half2_rn(1.0e-12f);
auto r = __hadd2(m, eps); auto r = __hadd2(m, eps);
...@@ -368,23 +373,23 @@ __global__ void triadd_layernorm_kernel( ...@@ -368,23 +373,23 @@ __global__ void triadd_layernorm_kernel(
{ {
int idx = i + start; int idx = i + start;
in_data[i] = input1[idx] + input2[idx] + input3[idx]; in_data[i] = input1[idx] + input2[idx] + input3[idx];
in_data_reduce[i] = in_data[i] * rnum; in_data_reduce[i] = __half2float(in_data[i]) * rnum;
} }
auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size); auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = in_data[i] - m; in_data[i] = in_data[i] - m;
in_data_reduce[i] = in_data[i] * in_data[i] * rnum; in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
} }
m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size) + 1.0e-12f; m = __half2float(block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size)) + 1.0e-12f;
auto r = rsqrt(m); auto r = rsqrt(m);
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
output[idx] = in_data[i] * r; output[idx] = __half2float(in_data[i]) * r;
} }
} }
...@@ -397,29 +402,30 @@ void triadd_layernorm(hipStream_t stream, ...@@ -397,29 +402,30 @@ void triadd_layernorm(hipStream_t stream,
auto in_s = arg1.get_shape(); auto in_s = arg1.get_shape();
auto type = in_s.type(); auto type = in_s.type();
auto batch_item_num = in_s.lens().back(); auto batch_item_num = in_s.lens().back();
if(type == shape::half_type and (batch_item_num % 2) == 0)
{
auto half2_block_size = compute_block_size(batch_item_num, 1024);
int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4;
triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
}
// if(type == shape::half_type and (batch_item_num % 2) == 0) // if(type == shape::half_type and (batch_item_num % 2) == 0)
// { // {
// auto reduce_block_size = compute_block_size(batch_item_num, 1024); // auto half2_block_size = compute_block_size(batch_item_num, 1024);
// int block_num = in_s.elements() / batch_item_num; // int block_num = in_s.elements() / batch_item_num;
// int shared_size = batch_item_num * 2 * in_s.type_size(); // int shared_size = batch_item_num * 2 * in_s.type_size();
// reduce_block_size = reduce_block_size / 2; // half2_block_size = half2_block_size / 4;
// triadd_layernorm_kernel_half<<<block_num, reduce_block_size, shared_size, stream>>>( // triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
// arg1.data(), // arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
// arg2.data(),
// arg3.data(),
// result.data(),
// batch_item_num,
// reduce_block_size);
// } // }
// if(type == shape::half_type and (batch_item_num % 2) == 0)
if(type == shape::half_type)
{
auto reduce_block_size = compute_block_size(batch_item_num, 1024);
int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size();
reduce_block_size = reduce_block_size / 2;
triadd_layernorm_kernel<__half><<<block_num, reduce_block_size, shared_size, stream>>>(
arg1.data(),
arg2.data(),
arg3.data(),
result.data(),
batch_item_num,
reduce_block_size);
}
else else
{ {
layernorm_fusion(stream, result, arg1, arg2, arg3)( layernorm_fusion(stream, result, arg1, arg2, arg3)(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment