Unverified Commit 39274045 authored by Zhao Tianyu, committed by GitHub

Add dropouts to GPT-NeoX (#24680)

* add attention dropout, post-attention dropout, and post-MLP dropout to gpt-neox (see the usage sketch after this commit message)

* fix typo

* add documentation

* fix too long line

* ran the repository style and consistency scripts; their output follows:
Checking/fixing src/transformers/models/gpt_neox/configuration_gpt_neox.py src/transformers/models/gpt_neox/modeling_gpt_neox.py
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
python utils/check_doc_toc.py --fix_and_overwrite
running deps_table_update
updating src/transformers/dependency_versions_table.py
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
Checking all models are included.
Checking all models are public.
Checking all models are properly tested.
Checking all objects are properly documented.
Checking all models are in at least one auto class.
Checking all names in auto name mappings are defined.
Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.
Checking all auto mappings could be imported.
Checking all objects are equally (across frameworks) in the main __init__.
python utils/check_inits.py
python utils/check_config_docstrings.py
python utils/check_config_attributes.py
python utils/check_doctest_list.py
python utils/update_metadata.py --check-only
python utils/check_task_guides.py
parent fb3b22c3
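For readers who want to try the new options, here is a minimal usage sketch based on the configuration changes in this commit. It assumes a transformers release that includes this change; the sizes and dropout values are arbitrary examples, not recommendations.

```python
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Build a small GPT-NeoX config with the new dropout options enabled.
# attention_dropout is applied to the attention probabilities;
# hidden_dropout to the embeddings, post-attention and post-MLP hidden states.
config = GPTNeoXConfig(
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=1024,
    attention_dropout=0.1,  # new in this PR, defaults to 0.0
    hidden_dropout=0.1,     # new in this PR, defaults to 0.0
)

model = GPTNeoXForCausalLM(config)
model.train()  # dropout is only active in training mode
```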
@@ -56,6 +56,11 @@ class GPTNeoXConfig(PretrainedConfig):
             percentage of hidden dimensions to allocate to rotary embeddings
         rotary_emb_base (`int`, *optional*, defaults to 10000)
             base for computing rotary embeddings frequency
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability of the attention scores.
+        hidden_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio of (1) the word embeddings, (2) the post-attention hidden states, and (3) the post-mlp
+            hidden states.
         classifier_dropout (`float`, *optional*, defaults to 0.1):
             Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
@@ -99,6 +104,8 @@ class GPTNeoXConfig(PretrainedConfig):
         hidden_act="gelu",
         rotary_pct=0.25,
         rotary_emb_base=10000,
+        attention_dropout=0.0,
+        hidden_dropout=0.0,
         classifier_dropout=0.1,
         max_position_embeddings=2048,
         initializer_range=0.02,
@@ -120,6 +127,8 @@ class GPTNeoXConfig(PretrainedConfig):
         self.hidden_act = hidden_act
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
+        self.attention_dropout = attention_dropout
+        self.hidden_dropout = hidden_dropout
         self.classifier_dropout = classifier_dropout
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
...
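Since both new options default to 0.0, existing configs keep their previous behavior: `nn.Dropout(p=0.0)` passes tensors through unchanged, and any dropout layer is a no-op in eval mode. A small standalone check in plain PyTorch, independent of this diff:

```python
import torch
from torch import nn

x = torch.randn(2, 4, 8)

# p=0.0 keeps the tensor unchanged even in training mode,
# so the default config reproduces the old forward pass.
assert torch.equal(nn.Dropout(p=0.0)(x), x)

# With p > 0.0, dropout only fires in training mode.
drop = nn.Dropout(p=0.5)
drop.eval()
assert torch.equal(drop(x), x)  # inference: identity
drop.train()
y = drop(x)  # training: elements zeroed and the rest rescaled by 1 / (1 - p)
```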
@@ -114,6 +114,8 @@ class GPTNeoXAttention(nn.Module):
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.attention_dropout = nn.Dropout(config.attention_dropout)

     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -245,6 +247,8 @@ class GPTNeoXAttention(nn.Module):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask

+        attn_weights = self.attention_dropout(attn_weights)

         attn_output = torch.matmul(attn_weights, value)
         return attn_output, attn_weights
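To make the hunk above concrete, here is a simplified sketch of scaled dot-product attention with dropout applied to the attention probabilities after the softmax and the optional head mask, before they are multiplied with the values. This is an illustration, not the actual GPTNeoXAttention implementation; the helper name and tensor shapes are assumed.

```python
import torch
from torch import nn

def attn_with_dropout(query, key, value, attention_dropout, head_mask=None):
    # query/key/value: [batch, heads, seq_len, head_dim]
    scores = torch.matmul(query, key.transpose(-1, -2)) / key.size(-1) ** 0.5
    attn_weights = scores.softmax(dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    # dropout on the normalized attention probabilities, as in the hunk above
    attn_weights = attention_dropout(attn_weights)
    return torch.matmul(attn_weights, value), attn_weights

q = k = v = torch.randn(1, 8, 16, 32)
out, weights = attn_with_dropout(q, k, v, nn.Dropout(p=0.1))
```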
@@ -320,6 +324,8 @@ class GPTNeoXLayer(nn.Module):
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
+        self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
         self.attention = GPTNeoXAttention(config)
         self.mlp = GPTNeoXMLP(config)
@@ -343,12 +349,14 @@ class GPTNeoXLayer(nn.Module):
             output_attentions=output_attentions,
         )
         attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
+        attn_output = self.post_attention_dropout(attn_output)
         outputs = attention_layer_outputs[1:]

         if self.use_parallel_residual:
             # pseudocode:
             # x = x + attn(ln1(x)) + mlp(ln2(x))
             mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+            mlp_output = self.post_mlp_dropout(mlp_output)
             hidden_states = mlp_output + attn_output + hidden_states
         else:
             # pseudocode:
@@ -356,6 +364,7 @@ class GPTNeoXLayer(nn.Module):
             # x = x + mlp(ln2(x))
             attn_output = attn_output + hidden_states
             mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
+            mlp_output = self.post_mlp_dropout(mlp_output)
             hidden_states = mlp_output + attn_output

         if use_cache:
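The two hunks above place `post_attention_dropout` on the attention output and `post_mlp_dropout` on the MLP output, in both residual variants, before the residual additions. Below is a condensed, hedged sketch of that data flow; the function name is invented, the modules are simple stand-ins, and the real layer also handles caching, masks, and attention outputs as shown in the diff.

```python
import torch
from torch import nn

def layer_forward(x, attn, mlp, ln1, ln2,
                  post_attention_dropout, post_mlp_dropout,
                  use_parallel_residual=True):
    # attention branch, with the new post-attention dropout
    attn_output = post_attention_dropout(attn(ln1(x)))
    if use_parallel_residual:
        # x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_output = post_mlp_dropout(mlp(ln2(x)))
        return mlp_output + attn_output + x
    # x = x + attn(ln1(x)); then x = x + mlp(ln2(x))
    attn_output = attn_output + x
    mlp_output = post_mlp_dropout(mlp(ln2(attn_output)))
    return mlp_output + attn_output

hidden = 64
x = torch.randn(2, 10, hidden)
out = layer_forward(
    x,
    attn=nn.Linear(hidden, hidden),   # stand-in for GPTNeoXAttention
    mlp=nn.Linear(hidden, hidden),    # stand-in for GPTNeoXMLP
    ln1=nn.LayerNorm(hidden),         # input_layernorm
    ln2=nn.LayerNorm(hidden),         # post_attention_layernorm
    post_attention_dropout=nn.Dropout(0.1),
    post_mlp_dropout=nn.Dropout(0.1),
)
```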
@@ -429,6 +438,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         self.config = config
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.emb_dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -533,7 +543,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_in(input_ids)

-        hidden_states = inputs_embeds
+        hidden_states = self.emb_dropout(inputs_embeds)

         if self.gradient_checkpointing and self.training:
             if use_cache:
...
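Finally, a hedged sketch of enabling these dropouts when fine-tuning an existing checkpoint: `from_pretrained` normally forwards unrecognized keyword arguments to the config, so the new fields can be overridden at load time. The checkpoint name is only an example, and this assumes a transformers release that includes this commit.

```python
from transformers import GPTNeoXForCausalLM

# Override the 0.0 defaults stored in the checkpoint's config.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",  # example GPT-NeoX checkpoint
    attention_dropout=0.1,
    hidden_dropout=0.1,
)

model.train()  # dropout only has an effect in training mode
print(model.config.attention_dropout, model.config.hidden_dropout)  # 0.1 0.1
```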