Unverified Commit 1337e81e authored by Sandeep Subramanian, committed by GitHub

Time dimension shape check for fused scale mask softmax kernel (#1421)

* Time dimension shape check for fused scale mask softmax kernel
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Add shape test
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix mask shape
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
parent 5ff5a884
@@ -165,6 +165,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             and mask is not None  # mask tensor must not be None
             and 16 < sk <= 2048  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisible by 4
+            and sk % 4 == 0  # sk must be divisible by 4
             and attn_batches % 4 == 0  # np * b must be divisible by 4
         ):
             if 0 <= sk <= 2048:
@@ -54,8 +54,8 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
         attention_scores.shape = [4, 12, 24, 24]
         mask.shape = [4, 1, 24, 24]
         """
-        for (dtype, scale, softmax_in_fp32) in itertools.product(
-            (torch.half, torch.bfloat16), (None, 2.0), (False, True),
+        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
+            (torch.half, torch.bfloat16), (None, 2.0), (False, True), ((4, 12, 24, 24), (32, 12, 4, 214)),
         ):
             with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                 input_in_fp16 = dtype == torch.half
@@ -79,13 +79,14 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
                 )
                 attention_scores_0 = (
-                    torch.randn((4, 12, 24, 24))
+                    torch.randn(shape)
                     .to(device="cuda", dtype=dtype)
                     .requires_grad_(True)
                 )
                 with torch.no_grad():
                     attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
-                mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
+                mask_shape = (shape[0],) + (1,) + shape[2:]
+                mask = torch.randint(0, 2, mask_shape, device="cuda").bool()
                 expected = fused_fn(attention_scores_0, mask)
                 actual = torch_fn(attention_scores_1, mask)
                 torch.testing.assert_close(actual, expected)
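The updated test derives the mask shape from the score shape by keeping a singleton head dimension, so the boolean mask broadcasts across all attention heads: (b, np, sq, sk) scores against a (b, 1, sq, sk) mask. Note that the second new shape has sk = 214, which is not a multiple of 4 and therefore exercises the sk % 4 == 0 check added above. A quick shape-only sanity check (plain Python, no GPU needed):

# Derive the broadcastable mask shape exactly as the test does.
for shape in ((4, 12, 24, 24), (32, 12, 4, 214)):
    mask_shape = (shape[0],) + (1,) + shape[2:]
    print(shape, "->", mask_shape)
# (4, 12, 24, 24)  -> (4, 1, 24, 24)
# (32, 12, 4, 214) -> (32, 1, 4, 214)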