"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "2d9be807a9690fff140f0e8ba9cbd297edd6d502"
Commit 0c121609 authored by Catheriany's avatar Catheriany
Browse files

issue/312: causal softmax算子支持bf16

parent 9e0773f6
...@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype, ...@@ -38,6 +38,12 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
batch_size, seq_len, total_seq_len, batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i, y_stride_b, y_stride_i,
x_stride_b, x_stride_i); x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __hpcc_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__hpcc_bfloat16 *)y, (const __hpcc_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float> causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x, <<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment