Commit 31e54f93 authored by Catheriany's avatar Catheriany
Browse files

issue/282: 添加max_reduction测试

parent 53468445
......@@ -30,6 +30,7 @@ _TEST_CASES_ = [
((32, 5, 5), None, None),
((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None),
((28, 15, 15), None, None),
]
# Data types used for testing
......@@ -93,7 +94,8 @@ def test(
)
x = torch.rand(shape, dtype=dtype).to(torch_device)
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
x = torch.where(mask == 1, torch.full_like(x, torch.finfo(x.dtype).max), x)
ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment