Unverified Commit 9ef9c397 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Cannot index `None` (#6984)

parent 08de989a
......@@ -464,6 +464,8 @@ class BertEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
......@@ -476,7 +478,7 @@ class BertEncoder(nn.Module):
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[i],
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
......@@ -484,7 +486,7 @@ class BertEncoder(nn.Module):
layer_outputs = layer_module(
hidden_states,
attention_mask,
head_mask[i],
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment