Unverified Commit 248fa1ae authored by Suraj Patil, committed by GitHub

fix T5 head mask in model_parallel (#9726)

* fix head mask in model_parallel

* pass correct head mask
parent ca422e3d
@@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel):
         hidden_states = self.dropout(inputs_embeds)
 
         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+            layer_head_mask = head_mask[i]
+            encoder_layer_head_mask = encoder_head_mask[i]
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel):
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
-                if not (isinstance(head_mask, list) and head_mask[0] is None):
-                    head_mask = head_mask.to(hidden_states.device)
-                if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
-                    encoder_head_mask = encoder_head_mask.to(hidden_states.device)
+                if layer_head_mask is not None:
+                    layer_head_mask = layer_head_mask.to(hidden_states.device)
+                if encoder_layer_head_mask is not None:
+                    encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel):
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_extended_attention_mask,
                     encoder_decoder_position_bias=encoder_decoder_position_bias,
-                    layer_head_mask=head_mask[i],
-                    encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
+                    layer_head_mask=layer_head_mask,
+                    encoder_layer_head_mask=encoder_layer_head_mask,
                     past_key_value=past_key_value,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
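Why the change matters, as a minimal sketch (not the library's code): `T5Stack` receives `head_mask` either as a tensor of shape `(num_layers, num_heads)` or, when the caller supplies no mask, as a list of `None` entries, one per layer. The commit indexes the per-layer slice once at the top of the loop and moves only that slice to the current block's device, rather than repeatedly reassigning the whole `head_mask` inside the loop. The layer count, device list, and variable names below are illustrative assumptions.

    import torch

    num_layers, num_heads = 2, 8

    # Per-layer head mask: 1.0 keeps an attention head, 0.0 masks it out.
    head_mask = torch.ones(num_layers, num_heads)
    head_mask[1, 0] = 0.0  # e.g. mask head 0 of layer 1

    # Under model parallelism the blocks of the stack live on different devices.
    devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() > 1 else ["cpu", "cpu"]

    for i, device in enumerate(devices):
        # Index first, then move only the (num_heads,)-shaped slice to the
        # device hosting block i -- the pattern the commit introduces.
        layer_head_mask = head_mask[i]
        if layer_head_mask is not None:  # entry is None when no mask was supplied
            layer_head_mask = layer_head_mask.to(device)
        # ... layer_head_mask would be passed to the i-th T5 block here ...
        print(i, layer_head_mask.device, layer_head_mask.tolist())

The old `isinstance(head_mask, list)` guard existed only to handle the list-of-`None` default; checking the per-layer entry directly makes that special case unnecessary.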