Unverified Commit 3f60d11a authored by Alessandro Palla's avatar Alessandro Palla Committed by GitHub
Browse files

Improve _update_causal_mask performance (#29210)

* Fix issue 29206

* Fix style
parent 75ed76ec
......@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
......@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
return causal_mask
......
......@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
......@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
return causal_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment