"vscode:/vscode.git/clone" did not exist on "32d5f028ae6dc7d5eda078e3477b519359673f74"
Commit 714c1b4f authored by Tri Dao

[Bert] Fix embedding layer norm before embedding dropout

parent ef1ba918
@@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+        if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
         assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
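The changed check presumes that both dropout_add_layer_norm and layer_norm are optional imports that fall back to None when the fused CUDA extension is not installed. A minimal sketch of that guard is below; the import path is a guess based on the error message in the diff, not necessarily the file's actual import.

# Assumed optional-import guard (sketch only; the real import path may differ).
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
    dropout_add_layer_norm, layer_norm = None, None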
@@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
-            hidden_states = self.emb_drop(hidden_states)
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = dropout_add_layer_norm(
-                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
-                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
-            )
+            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
+                                       self.emb_ln.eps)
+        hidden_states = self.emb_drop(hidden_states)
         if masked_tokens_mask is not None:
             batch_size, seqlen = input_ids.shape[:2]
...
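To make the fix concrete: after this hunk, LayerNorm is applied to the embedding output first and dropout runs afterwards, in both the fused and non-fused branches, matching the original BERT ordering. The sketch below illustrates that ordering in plain PyTorch; the EmbeddingPostprocess class and the use of torch.nn.functional.layer_norm as a stand-in for the fused layer_norm kernel are illustrative assumptions, not code from the repository.

# Illustrative sketch only: plain-PyTorch stand-in for the path in the hunk above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingPostprocess(nn.Module):
    """Applies LayerNorm before dropout, matching the fixed ordering."""
    def __init__(self, hidden_size, dropout_p=0.1, eps=1e-12, fused=False):
        super().__init__()
        self.emb_ln = nn.LayerNorm(hidden_size, eps=eps)
        self.emb_drop = nn.Dropout(dropout_p)
        self.fused_dropout_add_ln = fused

    def forward(self, hidden_states):
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            # F.layer_norm stands in here for the fused layer_norm kernel in the diff
            hidden_states = F.layer_norm(hidden_states, hidden_states.shape[-1:],
                                         self.emb_ln.weight, self.emb_ln.bias,
                                         self.emb_ln.eps)
        # Dropout now runs after LayerNorm in both branches.
        return self.emb_drop(hidden_states)

# Usage: emb = EmbeddingPostprocess(768); out = emb(torch.randn(2, 16, 768))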
@@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)
@@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
-        if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
 
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)
 
     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
...
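The two GPT hunks move the sequence-parallel parameter sync out of __init__ and into tie_weights, so the sync always runs after lm_head.weight has been tied to the word-embedding weight. Below is a hedged sketch of that flow; GPTLMHeadModelSketch and sync_params are illustrative stand-ins (sync_sequence_parallel_params in the repository may behave differently).

# Illustrative sketch only; the names GPTLMHeadModelSketch and sync_params are invented here.
import torch
import torch.nn as nn

def sync_params(module, process_group):
    # Stand-in for sync_sequence_parallel_params: broadcast every parameter
    # from rank 0 of the group so all ranks hold identical copies.
    for p in module.parameters():
        torch.distributed.broadcast(p.detach(), src=0, group=process_group)

class GPTLMHeadModelSketch(nn.Module):
    def __init__(self, vocab_size, hidden_size, process_group=None):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.process_group = process_group
        self.tie_weights()  # called once at the end of __init__, as in the diff

    def tie_weights(self):
        # Share storage between the output projection and the input embedding,
        # then sync replicated parameters when a process group is in use.
        self.lm_head.weight = self.word_embeddings.weight
        if self.process_group is not None:
            sync_params(self, self.process_group)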