Unverified Commit c1780ce7 authored by baeseongsu's avatar baeseongsu Committed by GitHub
Browse files

fix head_mask for albert encoder part(`AlbertTransformer`) (#11596)

* fix head mask for albert encoder part

* fix head_mask for albert encoder part
parent 864c1dfe
...
@@ -450,6 +450,8 @@ class AlbertTransformer(nn.Module):
         all_hidden_states = (hidden_states,) if output_hidden_states else None
         all_attentions = () if output_attentions else None

+        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
+
         for i in range(self.config.num_hidden_layers):
             # Number of layers in a hidden group
             layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment