Unverified Commit f3a075b7 authored by Derui Yang's avatar Derui Yang Committed by GitHub
Browse files

Merge pull request #313 from InfiniTensor/issue/312

issue/312: 沐曦BF16编译问题
parents fa2b3207 ac807c71
...@@ -60,4 +60,9 @@ __forceinline__ __device__ __half ...@@ -60,4 +60,9 @@ __forceinline__ __device__ __half
exp_(const __half x) { exp_(const __half x) {
return hexp(x); return hexp(x);
} }
// bfloat16 overload of exp_: forwards to hexp, mirroring the __half overload
// defined directly above it.
__forceinline__ __device__ __hpcc_bfloat16
exp_(const __hpcc_bfloat16 x) {
    return hexp(x);
}
#endif #endif
...@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype, ...@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
batch_size, seq_len, total_seq_len, batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i, y_stride_b, y_stride_i,
x_stride_b, x_stride_i); x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __hpcc_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__hpcc_bfloat16 *)y, (const __hpcc_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float> causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x, <<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
......
...@@ -107,6 +107,11 @@ struct CudaTval<fp16_t> { ...@@ -107,6 +107,11 @@ struct CudaTval<fp16_t> {
using Type = half; using Type = half;
}; };
// Specialization of CudaTval for bf16_t: exposes __hpcc_bfloat16 as the
// member alias Type.
template <>
struct CudaTval<bf16_t> {
    typedef __hpcc_bfloat16 Type;
};
// ↑↑↑ 通过特化将 fp16_t 转换为 half // ↑↑↑ 通过特化将 fp16_t 转换为 half
// ↓↓↓ 用于采样过程的小型 kernel // ↓↓↓ 用于采样过程的小型 kernel
......
...@@ -34,15 +34,16 @@ infiniStatus_t Descriptor::create( ...@@ -34,15 +34,16 @@ infiniStatus_t Descriptor::create(
workspace_size = workspace_result.take(); \ workspace_size = workspace_result.take(); \
} break } break
#define CASE_I(CASE, Tidx) \ #define CASE_I(CASE, Tidx) \
case CASE: \ case CASE: \
switch (info.dt_p) { \ switch (info.dt_p) { \
CASE_P(INFINI_DTYPE_F16, Tidx, half); \ CASE_P(INFINI_DTYPE_F16, Tidx, half); \
CASE_P(INFINI_DTYPE_F32, Tidx, float); \ CASE_P(INFINI_DTYPE_BF16, Tidx, __hpcc_bfloat16); \
CASE_P(INFINI_DTYPE_F64, Tidx, double); \ CASE_P(INFINI_DTYPE_F32, Tidx, float); \
default: \ CASE_P(INFINI_DTYPE_F64, Tidx, double); \
abort(); \ default: \
} \ abort(); \
} \
break break
switch (info.dt_i) { switch (info.dt_i) {
......
...@@ -58,6 +58,10 @@ infiniStatus_t launchKernel( ...@@ -58,6 +58,10 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float); LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float); LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment