Unverified commit 1337e81e, authored by Sandeep Subramanian, committed by GitHub

Time dimension shape check for fused scale mask softmax kernel (#1421)

* Time dimension shape check for fused scale mask softmax kernel
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Add shape test
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix mask shape
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
parent 5ff5a884
```diff
@@ -165,6 +165,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             and mask is not None  # mask tensor must not be None
             and 16 < sk <= 2048  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisor of 4
+            and sk % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 2048:
```
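For context, the gate being extended here decides whether the fused CUDA kernel can be used at all; the new line mirrors the existing `sq` check so that the key/time dimension `sk` must also be a multiple of 4. A minimal sketch of the shape logic (a hypothetical standalone helper; the real `is_kernel_available` method also checks the fusion flag, fp16/bf16 input, and attention-mask type):

```python
def fused_kernel_available(sq, sk, b, np, mask):
    """Sketch of the shape gate only -- not the full apex check."""
    attn_batches = b * np  # np = number of attention heads
    return (
        mask is not None           # mask tensor must not be None
        and 16 < sk <= 2048        # sk must be within 16 ~ 2048
        and sq % 4 == 0            # sq must be a multiple of 4
        and sk % 4 == 0            # new: sk must be a multiple of 4
        and attn_batches % 4 == 0  # np * b must be a multiple of 4
    )

# The original test shape passes; the new one (sk = 214, 214 % 4 == 2)
# now routes to the unfused PyTorch fallback instead of the CUDA kernel.
assert fused_kernel_available(sq=24, sk=24, b=4, np=12, mask=True)
assert not fused_kernel_available(sq=4, sk=214, b=32, np=12, mask=True)
```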
......@@ -54,8 +54,8 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
attention_scores.shape = [4, 12, 24, 24]
mask.shape = [4, 1, 24, 24]
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16), (None, 2.0), (False, True),
for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
(torch.half, torch.bfloat16), (None, 2.0), (False, True), ((4, 12, 24, 24), (32, 12, 4, 214))
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
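The test loop now sweeps a second score shape alongside the original one. Assuming the grid above, the product expands to 16 sub-tests, and the added shape is picked to trip the new `sk % 4 == 0` gate:

```python
import itertools

import torch

# Illustration of the expanded test grid (not part of the diff itself).
cases = list(itertools.product(
    (torch.half, torch.bfloat16),         # dtype
    (None, 2.0),                          # scale
    (False, True),                        # softmax_in_fp32
    ((4, 12, 24, 24), (32, 12, 4, 214)),  # (b, np, sq, sk)
))
assert len(cases) == 16

# For (32, 12, 4, 214): sq = 4 and b * np = 384 are multiples of 4, but
# sk = 214 is not, so this shape exercises the unfused fallback path
# while still having to match the reference softmax output.
```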
```diff
@@ -79,13 +79,14 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
                 )
                 attention_scores_0 = (
-                    torch.randn((4, 12, 24, 24))
+                    torch.randn(shape)
                     .to(device="cuda", dtype=dtype)
                     .requires_grad_(True)
                 )
                 with torch.no_grad():
                     attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
-                mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
+                mask_shape = (shape[0],) + (1,) + shape[2:]
+                mask = torch.randint(0, 2, mask_shape, device="cuda").bool()
                 expected = fused_fn(attention_scores_0, mask)
                 actual = torch_fn(attention_scores_1, mask)
                 torch.testing.assert_close(actual, expected)
```
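The `mask_shape` expression is the "Fix mask shape" step of the commit: the mask keeps the batch and both time dimensions but collapses the head dimension to 1, so a single mask broadcasts across every attention head. A small CPU-only sketch of that rule, using the larger shape from the test grid:

```python
import torch

shape = (32, 12, 4, 214)                     # (b, np, sq, sk)
mask_shape = (shape[0],) + (1,) + shape[2:]  # -> (32, 1, 4, 214)

scores = torch.randn(shape)
mask = torch.randint(0, 2, mask_shape).bool()

# A (b, 1, sq, sk) mask broadcasts over the np head dimension of a
# (b, np, sq, sk) score tensor, which both softmax paths rely on.
assert scores.masked_fill(mask, float("-inf")).shape == shape
```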