Unverified Commit f3a075b7 authored by Derui Yang's avatar Derui Yang Committed by GitHub
Browse files

Merge pull request #313 from InfiniTensor/issue/312

issue/312: 沐曦BF16编译问题
parents fa2b3207 ac807c71
......@@ -60,4 +60,9 @@ __forceinline__ __device__ __half
exp_(const __half x) {
return hexp(x);
}
__forceinline__ __device__ __hpcc_bfloat16;
exp_(const __hpcc_bfloat16; x) {
return hexp(x);
}
#endif
......@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __hpcc_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__hpcc_bfloat16 *)y, (const __hpcc_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
......
......@@ -107,6 +107,11 @@ struct CudaTval<fp16_t> {
using Type = half;
};
template <>
struct CudaTval<bf16_t> {
using Type = __hpcc_bfloat16;
};
// ↑↑↑ 通过特化将 fp16_t 转换为 half
// ↓↓↓ 用于采样过程的小型 kernel
......
......@@ -38,6 +38,7 @@ infiniStatus_t Descriptor::create(
case CASE: \
switch (info.dt_p) { \
CASE_P(INFINI_DTYPE_F16, Tidx, half); \
CASE_P(INFINI_DTYPE_BF16, Tidx, __hpcc_bfloat16); \
CASE_P(INFINI_DTYPE_F32, Tidx, float); \
CASE_P(INFINI_DTYPE_F64, Tidx, double); \
default: \
......
......@@ -58,6 +58,10 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment