Commit 4916bae6 authored by Vijay Korthikanti

conditioning fused kernels

parent 872e38ea
@@ -119,11 +119,13 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
+        attn_batch_size = data_size[0] * data_size[1]
         assert input.dim() == 4
         # invoke custom kernel
         if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+           query_seq_len % 4 == 0 and key_seq_len > 16 and \
+           attn_batch_size % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
...
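The dispatch condition this commit tightens can be restated as a standalone predicate. The sketch below is illustrative only: the helper name and signature are hypothetical and not part of the commit, which inlines this check inside FusedScaleMaskSoftmax.forward.

```python
import torch

def can_use_fused_softmax_kernel(attention_scores: torch.Tensor,
                                 mask,
                                 input_in_fp16: bool,
                                 scaled_masked_softmax_fusion: bool) -> bool:
    """Hypothetical helper restating the shape gating added in this commit.

    The fused scaled-masked-softmax CUDA kernel is only invoked when the
    attention-score tensor satisfies the kernel's shape constraints.
    """
    # Scores are [batch, heads, query_seq_len, key_seq_len].
    assert attention_scores.dim() == 4
    query_seq_len = attention_scores.size(-2)
    key_seq_len = attention_scores.size(-1)
    # Effective batch seen by the kernel: batch * number of attention heads.
    attn_batch_size = attention_scores.size(0) * attention_scores.size(1)

    return (input_in_fp16
            and scaled_masked_softmax_fusion
            and mask is not None
            and 16 < key_seq_len <= 2048   # lower bound added in this commit
            and query_seq_len % 4 == 0
            and attn_batch_size % 4 == 0)  # divisibility check added in this commit

# Example: fp16 scores for batch 4, 16 heads, 128x128 attention.
scores = torch.randn(4, 16, 128, 128, dtype=torch.float16)
mask = torch.zeros(4, 1, 128, 128, dtype=torch.bool)
print(can_use_fused_softmax_kernel(scores, mask, True, True))  # True
```

When any of these checks fails, the module presumably falls back to its regular (unfused) PyTorch softmax path rather than calling the custom CUDA kernel.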