Unverified Commit 27b836c9 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #283 from InfiniTensor/issue/282

issue/282: Maca CausalSoftamx精度bug
parents af8bdb43 31e54f93
......@@ -18,7 +18,7 @@ INFINIOP_MACA_KERNEL causalSoftmax(
// [Reduce] Find max value in each row and store in shared memory
__shared__ Tdata max_;
Tdata max_0 = op::common_maca::reduce_op::max<BLOCK_SIZE, Tdata>(x, width);
Tdata max_0 = op::common_maca::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + blockIdx.x);
if (threadIdx.x == 0) {
max_ = max_0;
}
......
......@@ -30,6 +30,7 @@ _TEST_CASES_ = [
((32, 5, 5), None, None),
((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None),
((28, 15, 15), None, None),
]
# Data types used for testing
......@@ -93,7 +94,8 @@ def test(
)
x = torch.rand(shape, dtype=dtype).to(torch_device)
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
x = torch.where(mask == 1, torch.full_like(x, torch.finfo(x.dtype).max), x)
ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment