Unverified Commit 3f60d11a authored by Alessandro Palla, committed by GitHub

Improve _update_causal_mask performance (#29210)

* Fix issue 29206

* Fix style
parent 75ed76ec
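
The change is identical in GemmaModel and LlamaModel below: torch.finfo(dtype).min is computed once, kept in min_dtype, and reused, both as the fill value for padding positions and, further down, as the value the SDPA workaround compares against instead of recovering it with a causal_mask.min() reduction over the whole mask. A minimal sketch of the resulting pattern; the shapes, dtype, and toy attention_mask here are illustrative assumptions, not values from the patch:

import torch

# Illustrative setup (assumption, not from the patch).
batch_size, seq_len = 2, 8
dtype, device = torch.float16, "cpu"
causal_template = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

# Hoist the fill value once and reuse it everywhere below.
min_dtype = torch.finfo(dtype).min
causal_mask = causal_template[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)

# 2D padding mask (1 = keep, 0 = pad); left padding on the first sequence.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[0, :3] = 0

# Fill padded key positions with the same hoisted scalar instead of
# re-evaluating torch.finfo(dtype).min a second time.
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
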
@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         # We use the current dtype to avoid any overflows
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
+        min_dtype = torch.finfo(dtype).min
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
 
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
-                padding_mask, torch.finfo(dtype).min
-            )
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
 
         if self.config._attn_implementation == "sdpa" and attention_mask is not None:
             # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
                 # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
                 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
-                    dtype
-                )
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
 
         return causal_mask
@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         # We use the current dtype to avoid any overflows
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
+        min_dtype = torch.finfo(dtype).min
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
 
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
-                padding_mask, torch.finfo(dtype).min
-            )
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
 
         if self.config._attn_implementation == "sdpa" and attention_mask is not None:
             # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
                 # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
                 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
-                    dtype
-                )
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
 
         return causal_mask
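
The second hunk in each file is where the hoisted scalar pays off for the SDPA path: rows that are entirely min_dtype (fully masked, e.g. from left padding, see pytorch/pytorch#110213) are detected by comparing against min_dtype directly rather than by calling causal_mask.min(), which drops an extra full reduction over the mask tensor. A self-contained sketch of what that line does on a toy mask; the shapes and values are assumptions, not taken from the patch:

import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Toy 4x4 mask for one sequence and one broadcast head; row 0 is fully
# masked (every entry equals min_dtype), as happens with left padding.
causal_mask = torch.zeros(1, 1, 4, 4, dtype=dtype)
causal_mask[0, 0, 0, :] = min_dtype
causal_mask[0, 0, 1, 2:] = min_dtype
causal_mask[0, 0, 2, 3:] = min_dtype

# Old: torch.all(causal_mask == causal_mask.min(), ...) recomputed a value
# that is already known. New: compare against the hoisted min_dtype scalar.
fully_masked_rows = torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
causal_mask = causal_mask.mul(~fully_masked_rows).to(dtype)

# Row 0 is now zeroed, so no position in it is masked and the
# memory-efficient SDPA kernel never sees an all-masked row.
print(causal_mask[0, 0, 0])
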