Unverified Commit 226b0e46 authored by wangxu, committed by GitHub

Add a use_parallel_residual argument to control how the residual is computed (#18695)

* Add a gpt_j_residual argument to control how the residual is computed

* Put duplicate code outside of the if block

* Rename parameter "gpt_j_residual" to "use_parallel_residual" and set the default value to True
parent 88f597ba
src/transformers/models/gpt_neox/configuration_gpt_neox.py

@@ -66,6 +66,9 @@ class GPTNeoXConfig(PretrainedConfig):
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if `config.is_decoder=True`.
+        use_parallel_residual (`bool`, *optional*, defaults to `True`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales (e.g. 20B).

     Example:
     ```python
@@ -99,6 +102,7 @@ class GPTNeoXConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        use_parallel_residual=True,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -115,3 +119,4 @@ class GPTNeoXConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
+        self.use_parallel_residual = use_parallel_residual
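As a quick illustration (not part of this diff), the new flag can simply be toggled when constructing a config; the tiny sizes below are arbitrary and only serve to keep the sketch lightweight:

```python
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Arbitrary, tiny sizes -- just enough to exercise the new flag.
config = GPTNeoXConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    use_parallel_residual=False,  # opt out of the parallel formulation
)
model = GPTNeoXForCausalLM(config)
print(model.config.use_parallel_residual)  # False; the default stays True
```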
src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -300,6 +300,7 @@ class GPTNeoXMLP(nn.Module):
 class GPTNeoXLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = GPTNeoXAttention(config)
@@ -314,28 +315,37 @@ class GPTNeoXLayer(nn.Module):
         layer_past=None,
         output_attentions=False,
     ):
-        residual = hidden_states
-        ln_out = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
-            ln_out,
+            self.input_layernorm(hidden_states),
             attention_mask=attention_mask,
             layer_past=layer_past,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        attn_output = attention_layer_outputs[0]  # output_attn: a, present, (attentions)
+        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]

-        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-        hidden_states = mlp_output + attn_output + residual
+        if self.use_parallel_residual:
+            # pseudocode:
+            # x = x + attn(ln1(x)) + mlp(ln2(x))
+            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+            hidden_states = mlp_output + attn_output + hidden_states
+        else:
+            # pseudocode:
+            # x = x + attn(ln1(x))
+            # x = x + mlp(ln2(x))
+            attn_output = attn_output + hidden_states
+            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
+            hidden_states = mlp_output + attn_output

         if use_cache:
-            outputs = (hidden_states,) + outputs
+            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
         else:
-            outputs = (hidden_states,) + outputs[1:]
+            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)

-        return outputs  # hidden_states, present, (attentions)
+        return outputs


 GPT_NEOX_START_DOCSTRING = r"""
...
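For reference, a minimal standalone sketch of the two residual formulations that the branch above switches between. `ToyBlock` and its `nn.Linear` stand-ins for attention and the MLP are illustrative assumptions, not the library's actual classes:

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Illustrative stand-in for a GPT-NeoX layer; attn/mlp are placeholders."""

    def __init__(self, hidden_size, use_parallel_residual=True):
        super().__init__()
        self.use_parallel_residual = use_parallel_residual
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)  # placeholder for self-attention
        self.mlp = nn.Linear(hidden_size, hidden_size)  # placeholder for the feed-forward MLP

    def forward(self, x):
        if self.use_parallel_residual:
            # Parallel (GPT-NeoX-20B style): x = x + attn(ln1(x)) + mlp(ln2(x))
            return x + self.attn(self.ln1(x)) + self.mlp(self.ln2(x))
        # Sequential: x = x + attn(ln1(x)), then x = x + mlp(ln2(x))
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))


x = torch.randn(1, 8, 32)
print(ToyBlock(32, use_parallel_residual=True)(x).shape)   # torch.Size([1, 8, 32])
print(ToyBlock(32, use_parallel_residual=False)(x).shape)  # torch.Size([1, 8, 32])
```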