"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "d46f3322a45cb907777f568bb1377ab1b9457ad8"
Unverified commit 3b7b7c68, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Catch misaligned address errors in softmax (#390)



Catch misaligned address errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8aa2da17
......@@ -20,8 +20,9 @@ at::Tensor scaled_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
// Output
auto act_options = input.options().requires_grad(false);
......@@ -90,8 +91,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
TORCH_CHECK(mask.size(1) == 1);
TORCH_CHECK(mask.size(2) == query_seq_len);
......@@ -157,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_CHECK(seq_len <= 2048);
AT_ASSERTM(seq_len <= 2048, "Sequence length must be 2048 or less");
// Output
auto act_options = input.options().requires_grad(false);
......
......@@ -270,10 +270,11 @@ class FusedScaleMaskSoftmax(nn.Module):
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if (
if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096  # sk must be 16 < sk <= 4096
and sk % 8 == 0  # sk must be divisible by 8
and sq % 4 == 0  # sq must be divisible by 4
and attn_batches % 4 == 0  # np * b must be divisible by 4
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment