"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bc0e55a00bbbc825f27d851ccc58a749d18b4fd9"
Commit dda1adad authored by Rémi Louf's avatar Rémi Louf
Browse files

rename BertLayer to BertEncoderLayer

parent 0053c0e0
...@@ -315,9 +315,9 @@ class BertOutput(nn.Module): ...@@ -315,9 +315,9 @@ class BertOutput(nn.Module):
return hidden_states return hidden_states
class BertLayer(nn.Module): class BertEncoderLayer(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertLayer, self).__init__() super(BertEncoderLayer, self).__init__()
self.attention = BertAttention(config) self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config) self.intermediate = BertIntermediate(config)
self.output = BertOutput(config) self.output = BertOutput(config)
...@@ -336,7 +336,7 @@ class BertEncoder(nn.Module): ...@@ -336,7 +336,7 @@ class BertEncoder(nn.Module):
super(BertEncoder, self).__init__() super(BertEncoder, self).__init__()
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None): def forward(self, hidden_states, attention_mask=None, head_mask=None):
all_hidden_states = () all_hidden_states = ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment