Unverified Commit 3b7b7c68 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Catch misaligned address errors in softmax (#390)



Catch misaligned address errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8aa2da17
...@@ -20,8 +20,9 @@ at::Tensor scaled_softmax_forward(at::Tensor input, ...@@ -20,8 +20,9 @@ at::Tensor scaled_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2); const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3); const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096); AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
TORCH_CHECK(query_seq_len > 1); AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
// Output // Output
auto act_options = input.options().requires_grad(false); auto act_options = input.options().requires_grad(false);
...@@ -90,8 +91,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, ...@@ -90,8 +91,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
const int attn_heads = input.size(1); const int attn_heads = input.size(1);
const int query_seq_len = input.size(2); const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3); const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1); AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
TORCH_CHECK(pad_batches == 1 || pad_batches == batches); TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
TORCH_CHECK(mask.size(1) == 1); TORCH_CHECK(mask.size(1) == 1);
TORCH_CHECK(mask.size(2) == query_seq_len); TORCH_CHECK(mask.size(2) == query_seq_len);
...@@ -157,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, ...@@ -157,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
const int attn_batches = input.size(0); const int attn_batches = input.size(0);
const int seq_len = input.size(1); const int seq_len = input.size(1);
TORCH_CHECK(seq_len <= 2048); AT_ASSERTM(seq_len <= 2048, "Sequence length must be 2048 or less");
// Output // Output
auto act_options = input.options().requires_grad(false); auto act_options = input.options().requires_grad(false);
......
...@@ -270,10 +270,11 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -270,10 +270,11 @@ class FusedScaleMaskSoftmax(nn.Module):
"""Check FusedScaleMaskSoftmax kernel availability based on size""" """Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np attn_batches = b * np
if ( if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user want to fuse self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16 and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 4096 and 16 < sk <= 4096 # sk must be 16 ~ 4096
and sk % 8 == 0 # sk must be divisible by 8
and sq % 4 == 0 # sq must be divisible by 4 and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4 and attn_batches % 4 == 0 # np * b must be divisible by 4
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment