Unverified Commit 27b836c9 authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #283 from InfiniTensor/issue/282

issue/282: Maca CausalSoftmax精度bug
parents af8bdb43 31e54f93
...@@ -18,7 +18,7 @@ INFINIOP_MACA_KERNEL causalSoftmax( ...@@ -18,7 +18,7 @@ INFINIOP_MACA_KERNEL causalSoftmax(
// [Reduce] Find max value in each row and store in shared memory // [Reduce] Find max value in each row and store in shared memory
__shared__ Tdata max_; __shared__ Tdata max_;
Tdata max_0 = op::common_maca::reduce_op::max<BLOCK_SIZE, Tdata>(x, width); Tdata max_0 = op::common_maca::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + blockIdx.x);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
max_ = max_0; max_ = max_0;
} }
......
...@@ -30,6 +30,7 @@ _TEST_CASES_ = [ ...@@ -30,6 +30,7 @@ _TEST_CASES_ = [
((32, 5, 5), None, None), ((32, 5, 5), None, None),
((32, 20, 512), None, None), ((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None), ((32, 20, 512), (20480, 512, 1), None),
((28, 15, 15), None, None),
] ]
# Data types used for testing # Data types used for testing
...@@ -93,7 +94,8 @@ def test( ...@@ -93,7 +94,8 @@ def test(
) )
x = torch.rand(shape, dtype=dtype).to(torch_device) x = torch.rand(shape, dtype=dtype).to(torch_device)
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
x = torch.where(mask == 1, torch.full_like(x, torch.finfo(x.dtype).max), x)
ans = causal_softmax(x) ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride) x = rearrange_if_needed(x, x_stride)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment