Unverified Commit f88d4ad8 authored by 蒋帅宏(Shuaihong_Jiang)'s avatar 蒋帅宏(Shuaihong_Jiang) Committed by GitHub
Browse files

issue/254: 添加算子在CPU和CUDA上对BF16的支持,并增加相应的测试代码 (#255)



* issue/254: 添加算子在CPU和CUDA上对BF16的支持,并增加相应的测试代码

* issue/254: 将修改后的算子格式化后重新提交

* 修改与最新main的冲突

* 解决冲突后rms_norm原本的精度过不了了,现在由
{"atol": 5e-3, "rtol": 5e-3}更改为
{"atol": 8e-3, "rtol": 8e-3}

* rms_norm在debug模式下FP16的测试用例失败了(本地测试能通过,github上过不了),
所以将容差增大了两倍进行测试

* 将rms_normd的测试输入缩放0.5,将容差改回原始值来进行ci测试

* issue/254: 1.使用CHECK_DTYPE宏来进行数据类型检验
2.在test的utils.py中添加了设备对BF16支持的检验

* issue/254: rms_norm测试fp16容差由
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
改为torch.float16: {"atol": 2e-3, "rtol": 2e-3},
并删除对输入0.5的放缩

* issue/254: 在utils.py中debug方法和debug_all方法中
添加了对BF16的特判

* 修改支持BF16测试的设备类型检查方法

* 修改支持BF16测试的设备检查

* issue/254: reduce redundancy in rms_norm.py

* issue/254: add back the missing comment in rms_norm.py

* issue/254: add fp32 tolerance condition in causal_softmax.py

---------
Co-authored-by: default avatarZimin Li <coollizimin@gmail.com>
parent 105065e2
......@@ -65,4 +65,9 @@ __forceinline__ __device__ __half
exp_(const __half x) {
return hexp(x);
}
__forceinline__ __device__ __nv_bfloat16
exp_(const __nv_bfloat16 x) {
return hexp(x);
}
#endif
......@@ -176,8 +176,8 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
if constexpr (std::is_same_v<Tdata, fp16_t> || std::is_same_v<Tdata, bf16_t>) {
out[out_idx] = utils::cast<Tdata>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
} else {
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
}
......
......@@ -29,24 +29,24 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
const T *x_ = x + x_offset;
for (size_t j = info->total_seq_len - info->seq_len + i + 1; j < info->total_seq_len; j++) {
if constexpr (std::is_same<T, fp16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<fp16_t>(0.0f);
if constexpr (std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<T>(0.0f);
} else {
y_[j * info->y_stride_j] = 0.0f;
}
}
float val = op::common_cpu::reduce_op::max(x_, info->total_seq_len - info->seq_len + i + 1, info->x_stride_j);
for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) {
if constexpr (std::is_same<T, fp16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<fp16_t>(std::exp(utils::cast<float>(x_[j * info->x_stride_j]) - val));
if constexpr (std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<T>(std::exp(utils::cast<float>(x_[j * info->x_stride_j]) - val));
} else {
y_[j * info->y_stride_j] = std::exp(x_[j * info->x_stride_j] - val);
}
}
float sum = op::common_cpu::reduce_op::sum(y_, info->total_seq_len - info->seq_len + i + 1, info->y_stride_j);
for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) {
if constexpr (std::is_same<T, fp16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<fp16_t>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
if constexpr (std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<T>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
} else {
y_[j * info->y_stride_j] = y_[j * info->y_stride_j] / sum;
}
......@@ -64,6 +64,8 @@ infiniStatus_t Descriptor::calculate(
if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)y, (const fp16_t *)x));
} else if (_info.dtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(causal_softmax<bf16_t>(&_info, (bf16_t *)y, (const bf16_t *)x));
} else if (_info.dtype == INFINI_DTYPE_F32) {
CHECK_STATUS(causal_softmax<float>(&_info, (float *)y, (const float *)x));
} else {
......
......@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __nv_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
......
......@@ -29,7 +29,7 @@ public:
if (dtype != x_desc->dtype()) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
auto shape = y_desc->shape();
CHECK_SAME_SHAPE(shape, x_desc->shape());
......
......@@ -14,9 +14,7 @@ infiniStatus_t Descriptor::create(
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
......@@ -53,17 +51,17 @@ void calculate(
for (int k_ = 0; k_ < static_cast<int>(info.k); ++k_) {
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
sum += utils::cast<float>(*a_) * utils::cast<float>(*b_);
} else {
sum += *a_ * (*b_);
}
}
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
if (beta == 0) {
*c_ = utils::cast<fp16_t>(alpha * sum);
*c_ = utils::cast<Tdata>(alpha * sum);
} else {
*c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
*c_ = utils::cast<Tdata>(beta * utils::cast<float>(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
......@@ -86,6 +84,10 @@ infiniStatus_t Descriptor::calculate(
cpu::calculate<fp16_t>(_info, c, beta, a, b, alpha);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_BF16:
cpu::calculate<bf16_t>(_info, c, beta, a, b, alpha);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
cpu::calculate<float>(_info, c, beta, a, b, alpha);
return INFINI_STATUS_SUCCESS;
......
......@@ -20,9 +20,7 @@ infiniStatus_t Descriptor::create(
auto handle = reinterpret_cast<device::cuda::nvidia::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
......@@ -52,7 +50,10 @@ infiniStatus_t Descriptor::calculate(
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = CUDA_R_16BF;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
......
......@@ -40,6 +40,11 @@ struct ComputeType<fp16_t> {
using type = float;
};
template <>
struct ComputeType<bf16_t> {
using type = float;
};
struct Algo {
template <class Tidx, class Tval>
......
......@@ -37,6 +37,7 @@ infiniStatus_t Descriptor::create(
case CASE: \
switch (info.dt_p) { \
CASE_P(INFINI_DTYPE_F16, Tidx, half); \
CASE_P(INFINI_DTYPE_BF16, Tidx, __nv_bfloat16); \
CASE_P(INFINI_DTYPE_F32, Tidx, float); \
CASE_P(INFINI_DTYPE_F64, Tidx, double); \
default: \
......
......@@ -107,6 +107,11 @@ struct CudaTval<fp16_t> {
using Type = half;
};
template <>
struct CudaTval<bf16_t> {
using Type = __nv_bfloat16;
};
// ↑↑↑ 通过特化将 fp16_t 转换为 half
// ↓↓↓ 用于采样过程的小型 kernel
......
......@@ -18,8 +18,7 @@ struct RandomSampleInfo {
auto dt_p = probs_desc->dtype();
CHECK_DTYPE_ANY_INT(dt_i);
CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_OR_RETURN(result_desc->ndim() == 0, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
......
......@@ -87,6 +87,9 @@ class Calculate {
case INFINI_DTYPE_F16:
switch_f<Tidx, fp16_t>(algo, n, args);
break;
case INFINI_DTYPE_BF16:
switch_f<Tidx, bf16_t>(algo, n, args);
break;
case INFINI_DTYPE_F32:
switch_f<Tidx, float>(algo, n, args);
break;
......
......@@ -40,12 +40,15 @@ infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) {
return INFINI_STATUS_SUCCESS;
}
template <typename Tw>
infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, const Tw *w) {
template <typename T, typename Tw>
infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, const Tw *w) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t");
#pragma omp parallel for
for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) {
fp16_t *x_ = (fp16_t *)(x + i * info->x_strides[0]);
fp16_t *y_ = (fp16_t *)(y + i * info->y_strides[0]);
T *x_ = (T *)(x + i * info->x_strides[0]);
T *y_ = (T *)(y + i * info->y_strides[0]);
// [Reduce] sum of x^2 on last dimension
float ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]);
......@@ -56,10 +59,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
for (size_t j = 0; j < info->shape[1]; j++) {
if constexpr (std::is_same<Tw, float>::value) {
float val = utils::cast<float>(x_[j * info->x_strides[1]]) * w[j] * rms;
y_[j * info->y_strides[1]] = utils::cast<fp16_t>(val);
} else if constexpr (std::is_same<Tw, fp16_t>::value) {
y_[j * info->y_strides[1]] = utils::cast<T>(val);
} else if constexpr (std::is_same<Tw, T>::value) {
float val = utils::cast<float>(x_[j * info->x_strides[1]]) * utils::cast<float>(w[j]) * rms;
y_[j * info->y_strides[1]] = utils::cast<fp16_t>(val);
y_[j * info->y_strides[1]] = utils::cast<T>(val);
} else {
std::abort();
}
......@@ -75,9 +78,17 @@ infiniStatus_t Descriptor::calculate(
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
CHECK_STATUS(rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)x, (const float *)w));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)x, (const bf16_t *)w));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const float *)w));
CHECK_STATUS(rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)x, (const float *)w));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -60,6 +60,10 @@ infiniStatus_t launchKernel(
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
......
......@@ -32,11 +32,13 @@ public:
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
// For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
if (wtype != atype && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
// For FP32/FP64, activations and weights must be of the same type
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -49,14 +49,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
size_t pos0 = 2 * i;
size_t pos1 = 2 * i + 1;
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
cos__ = utils::cast<float>(cos_table[table_offset + i]);
y[y_offset + pos0] = utils::cast<fp16_t>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<fp16_t>(x0 * sin__ + x1 * cos__);
y[y_offset + pos0] = utils::cast<Tdata>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<Tdata>(x0 * sin__ + x1 * cos__);
} else {
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
......@@ -111,6 +111,8 @@ infiniStatus_t Descriptor::calculate(
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(fp16_t);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bf16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
......
......@@ -102,6 +102,8 @@ infiniStatus_t Descriptor::calculate(
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(__nv_bfloat16);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
......
......@@ -30,6 +30,17 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItem(
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, __nv_bfloat16>::value) {
auto &y = reinterpret_cast<__nv_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const __nv_bfloat162 &>(x_[x_offset + 2 * i]);
Tangle x0 = __low2bfloat16(x);
Tangle x1 = __high2bfloat16(x);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y = __floats2bfloat162_rn(y0, y1);
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
......
......@@ -78,7 +78,7 @@ public:
const infiniDtype_t pos_type = pos_desc->dtype();
CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE_ANY_INT(pos_type);
CHECK_OR_RETURN(y_desc->ndim() == 3
......
......@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
......@@ -39,6 +39,8 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<SwiGLUOp, bf16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment