Commit 0c121609 authored by Catheriany's avatar Catheriany
Browse files

issue/312: causal softmax算子支持bf16

parent 9e0773f6
......@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __hpcc_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__hpcc_bfloat16 *)y, (const __hpcc_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment