Unverified Commit 19d60bf8 authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #422 from InfiniTensor/issue/421

issue/421: 适配 rmsnorm 测例修改,支持 bf16 和 f16数据类型 weights
parents c0d1b0d0 1048c1bc
......@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock(
// Apply softmax
for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
if (sum_ != 0) {
y[col] = to<Tdata>(to<Tcompute>(y[col]) / sum_);
y[col] = Tdata(Tcompute(y[col]) / sum_);
} else {
y[col] = Tdata(0);
}
......
......@@ -27,7 +27,7 @@ __device__ void rmsnormBlock(
for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
Tdata xi = x[i];
Tweight wi = w[i];
y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms);
y[i] = Tdata(Tcompute(xi) * Tcompute(wi) * rms);
}
sync_cluster();
}
......
......@@ -95,10 +95,14 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(bfloat16_t, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(bfloat16_t, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
......
......@@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi) * to<Tcompute>(xi);
ss += Tcompute(xi) * Tcompute(xi);
}
__shared__ Tcompute temp_storage;
if (core_id() == 0) {
temp_storage = to<Tcompute>(0.f);
temp_storage = Tcompute(0.f);
}
sync_cluster();
......@@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi);
ss += Tcompute(xi);
}
__shared__ Tcompute temp_storage;
if (core_id() == 0) {
temp_storage = to<Tcompute>(0.f);
temp_storage = Tcompute(0.f);
}
sync_cluster();
......@@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count)
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = data_ptr[i];
max_val = fmax(max_val, to<Tdata>(xi));
max_val = fmax(max_val, Tdata(xi));
}
__shared__ Tdata temp_storage;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment