Unverified Commit e7e96a29 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #748 from InfiniTensor/issue/746

issue/746: 修复causal_softmax在长宽在1024边缘的计算错误
parents 51beebc6 0b6bdab0
......@@ -28,7 +28,7 @@ __device__ void causalSoftmaxKernel(
// 1 | * * * ... * * |
// 2 | * * * ... * * * |
// height: 3 col_id->
if (width + blockIdx.x >= threadIdx.x + height) {
if (width + blockIdx.x >= col + height) {
if constexpr (std::is_same_v<Tdata, half> || std::is_same_v<Tdata, cuda_bfloat16>) {
y[col] = hexp(x[col] - max_);
} else {
......
......@@ -32,6 +32,9 @@ _TEST_CASES_ = [
((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None),
((28, 15, 15), None, None),
((28, 1024, 1024), None, None),
((28, 1025, 1025), None, None),
((28, 1031, 1031), None, None),
]
# Data types used for testing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment