Unverified Commit ab2006e3 authored by Younes Belkada, committed by GitHub

BART - Fix attention mask device issue on copied models (#18540)

* attempt to fix attn mask device

* fix bart `_prepare_decoder_attention_mask`

- add correct device
- run `make fix-copies` to propagate the fix
parent 6bea7b81
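
A minimal sketch of the failure mode this commit addresses and the one-line fix. This is not the Transformers source: `expand_mask` below is a simplified stand-in for the library's `_expand_mask` helper, and the tensor shapes and device setup are made up for illustration.

import torch

def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # Simplified stand-in: expands [bsz, src_len] -> [bsz, 1, tgt_len, src_len]
    # and turns 0s into large negative values, mirroring the shape comment in the diff.
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

bsz, seq_len = 2, 8
attention_mask = torch.ones(bsz, seq_len)      # mask stays on CPU
inputs_embeds = torch.randn(bsz, seq_len, 16)
if torch.cuda.is_available():
    # Hypothetical setup: embeddings end up on GPU, e.g. when layers are
    # sharded across devices, while the user-supplied mask is still on CPU.
    inputs_embeds = inputs_embeds.to("cuda")

expanded_attn_mask = expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=seq_len)
# Without the explicit move, adding this mask to a causal mask that lives on
# inputs_embeds.device fails with "Expected all tensors to be on the same device".
expanded_attn_mask = expanded_attn_mask.to(inputs_embeds.device)
assert expanded_attn_mask.device == inputs_embeds.device

The diff below applies exactly this `.to(inputs_embeds.device)` move in `_prepare_decoder_attention_mask`; `make fix-copies` then propagates the change to every model that copies that method.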
@@ -915,7 +915,9 @@ class BartDecoder(BartPretrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -2116,7 +2116,9 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -854,7 +854,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -850,7 +850,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -860,7 +860,9 @@ class MarianDecoder(MarianPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -913,7 +913,9 @@ class MBartDecoder(MBartPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -534,7 +534,9 @@ class OPTDecoder(OPTPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -880,7 +880,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
@@ -887,7 +887,9 @@ class PLBartDecoder(PLBartPreTrainedModel):
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )