Unverified commit 19d60bf8, authored by zhangyue and committed by GitHub
Browse files

Merge pull request #422 from InfiniTensor/issue/421

issue/421: 适配 rmsnorm 测例修改,支持 bf16 和 f16数据类型 weights
parents c0d1b0d0 1048c1bc
...@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock( ...@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock(
// Apply softmax // Apply softmax
for (size_t col = core_id(); col < width; col += BLOCK_SIZE) { for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
if (sum_ != 0) { if (sum_ != 0) {
y[col] = to<Tdata>(to<Tcompute>(y[col]) / sum_); y[col] = Tdata(Tcompute(y[col]) / sum_);
} else { } else {
y[col] = Tdata(0); y[col] = Tdata(0);
} }
......
...@@ -27,7 +27,7 @@ __device__ void rmsnormBlock( ...@@ -27,7 +27,7 @@ __device__ void rmsnormBlock(
for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) { for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
Tdata xi = x[i]; Tdata xi = x[i];
Tweight wi = w[i]; Tweight wi = w[i];
y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms); y[i] = Tdata(Tcompute(xi) * Tcompute(wi) * rms);
} }
sync_cluster(); sync_cluster();
} }
......
...@@ -95,10 +95,14 @@ infiniStatus_t launchKernel( ...@@ -95,10 +95,14 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float); LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float); LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float); LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(bfloat16_t, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(bfloat16_t, float, float); LAUNCH_KERNEL(bfloat16_t, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
......
...@@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size ...@@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i]; Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi) * to<Tcompute>(xi); ss += Tcompute(xi) * Tcompute(xi);
} }
__shared__ Tcompute temp_storage; __shared__ Tcompute temp_storage;
if (core_id() == 0) { if (core_id() == 0) {
temp_storage = to<Tcompute>(0.f); temp_storage = Tcompute(0.f);
} }
sync_cluster(); sync_cluster();
...@@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun ...@@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i]; Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi); ss += Tcompute(xi);
} }
__shared__ Tcompute temp_storage; __shared__ Tcompute temp_storage;
if (core_id() == 0) { if (core_id() == 0) {
temp_storage = to<Tcompute>(0.f); temp_storage = Tcompute(0.f);
} }
sync_cluster(); sync_cluster();
...@@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) ...@@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count)
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i]; Tdata xi = data_ptr[i];
max_val = fmax(max_val, to<Tdata>(xi)); max_val = fmax(max_val, Tdata(xi));
} }
__shared__ Tdata temp_storage; __shared__ Tdata temp_storage;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment