Unverified commit bd262158, authored by Patrick von Platen and committed by GitHub

fix data type (#7513)

parent 62f5ae68
@@ -238,13 +238,20 @@ class ModuleUtilsMixin:
             seq_ids = torch.arange(seq_length, device=device)
             causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
             # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+            # causal and attention masks must have same type with pytorch version < 1.3
+            causal_mask = causal_mask.to(attention_mask.dtype)
+
             if causal_mask.shape[1] < attention_mask.shape[1]:
                 prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                 causal_mask = torch.cat(
-                    [torch.ones((batch_size, seq_length, prefix_seq_len), device=device), causal_mask], axis=-1
+                    [
+                        torch.ones(
+                            (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
+                        ),
+                        causal_mask,
+                    ],
+                    axis=-1,
                 )
-            # causal and attention masks must have same type with pytorch version < 1.3
-            causal_mask = causal_mask.to(attention_mask.dtype)
             extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
         else:
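What the diff fixes: `causal_mask` comes out of the `<=` comparison as a bool tensor, while `torch.ones` defaults to float32, so the pre-fix code concatenated tensors of mismatched dtypes and only cast to `attention_mask.dtype` afterwards. The fix moves the cast before the concatenation and creates the prefix ones with an explicit matching `dtype=`. Below is a minimal, self-contained sketch of the fixed logic outside the `ModuleUtilsMixin` context; the concrete shapes (batch of 2, sequence length 3, two cached past positions) are illustrative assumptions, not values from the commit.

import torch

# Illustrative setup: 2 sequences of length 3, with 2 cached (past) key/value positions.
batch_size, seq_length = 2, 3
past_length = 2
device = torch.device("cpu")

# The attention mask covers past + current tokens; its dtype drives the casts below.
attention_mask = torch.ones(batch_size, past_length + seq_length, device=device)

seq_ids = torch.arange(seq_length, device=device)
# Lower-triangular causal mask of shape (batch_size, seq_length, seq_length); dtype is bool here.
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# Cast *before* concatenating, so both pieces below share a dtype.
causal_mask = causal_mask.to(attention_mask.dtype)

if causal_mask.shape[1] < attention_mask.shape[1]:
    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
    causal_mask = torch.cat(
        [
            # Past positions are always visible; match causal_mask's dtype explicitly.
            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
            causal_mask,
        ],
        axis=-1,
    )

extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended_attention_mask.shape, extended_attention_mask.dtype)
# torch.Size([2, 1, 3, 5]) torch.float32

Recent PyTorch versions type-promote mixed-dtype inputs to torch.cat, but older ones raise an error; the explicit cast plus dtype= keeps the behavior consistent across versions, which is what the commit message's "fix data type" refers to.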