"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c1aaa439350051acdcd585946e91525502a6b063"
Unverified commit 226b0e46, authored by wangxu and committed by GitHub

Add a use_parallel_residual argument to control the residual computing way (#18695)

* Add a gpt_j_residual argument to control the residual computing way

* Put duplicate code outside of the if block

* Rename parameter "gpt_j_residual" to "use_parallel_residual" and set the default value to True
parent 88f597ba
src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -66,6 +66,9 @@ class GPTNeoXConfig(PretrainedConfig):
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if `config.is_decoder=True`.
+        use_parallel_residual (`bool`, *optional*, defaults to `True`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales (e.g. 20B).
     Example:
     ```python
@@ -99,6 +102,7 @@ class GPTNeoXConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        use_parallel_residual=True,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -115,3 +119,4 @@ class GPTNeoXConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
+        self.use_parallel_residual = use_parallel_residual
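As a usage note (a minimal sketch, assuming a `transformers` install that includes this change; the tiny model dimensions are arbitrary and only for illustration), the new argument is set like any other `GPTNeoXConfig` field and is read by every `GPTNeoXLayer`:

```python
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Default (use_parallel_residual=True) keeps the GPT-NeoX-20B "parallel" formulation:
#   x = x + attn(ln1(x)) + mlp(ln2(x))
parallel_cfg = GPTNeoXConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)

# use_parallel_residual=False switches every layer to the sequential residual path:
#   x = x + attn(ln1(x)); x = x + mlp(ln2(x))
sequential_cfg = GPTNeoXConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    use_parallel_residual=False,
)

model = GPTNeoXForCausalLM(sequential_cfg)
print(model.config.use_parallel_residual)  # False
```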
src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -300,6 +300,7 @@ class GPTNeoXMLP(nn.Module):
 class GPTNeoXLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = GPTNeoXAttention(config)
@@ -314,28 +315,37 @@ class GPTNeoXLayer(nn.Module):
         layer_past=None,
         output_attentions=False,
     ):
-        residual = hidden_states
-        ln_out = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
-            ln_out,
+            self.input_layernorm(hidden_states),
             attention_mask=attention_mask,
             layer_past=layer_past,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        attn_output = attention_layer_outputs[0]  # output_attn: a, present, (attentions)
+        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]

-        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-        hidden_states = mlp_output + attn_output + residual
+        if self.use_parallel_residual:
+            # pseudocode:
+            # x = x + attn(ln1(x)) + mlp(ln2(x))
+            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+            hidden_states = mlp_output + attn_output + hidden_states
+        else:
+            # pseudocode:
+            # x = x + attn(ln1(x))
+            # x = x + mlp(ln2(x))
+            attn_output = attn_output + hidden_states
+            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
+            hidden_states = mlp_output + attn_output

         if use_cache:
-            outputs = (hidden_states,) + outputs
+            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
         else:
-            outputs = (hidden_states,) + outputs[1:]
+            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)

-        return outputs  # hidden_states, present, (attentions)
+        return outputs
GPT_NEOX_START_DOCSTRING = r"""
......
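To make the behavioral difference concrete, here is a small self-contained sketch of the two residual paths the flag selects between. The `ln1`/`ln2`/`attn`/`mlp` callables are hypothetical stand-ins (plain `LayerNorm`/`Linear` modules), not the real GPT-NeoX sub-layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 8
x = torch.randn(1, 4, hidden_size)  # (batch, seq_len, hidden_size)

# Hypothetical stand-ins for the attention and MLP sub-layers.
ln1, ln2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attn, mlp = nn.Linear(hidden_size, hidden_size), nn.Linear(hidden_size, hidden_size)

# use_parallel_residual=True (default): both sub-layers see the same input and
# their outputs are added to the residual in a single step.
#   x = x + attn(ln1(x)) + mlp(ln2(x))
parallel_out = x + attn(ln1(x)) + mlp(ln2(x))

# use_parallel_residual=False: classic sequential residuals, with the MLP
# applied to the attention-updated hidden states.
#   x = x + attn(ln1(x))
#   x = x + mlp(ln2(x))
h = x + attn(ln1(x))
sequential_out = h + mlp(ln2(h))

# The two formulations are not numerically equivalent, so the flag has to match
# how a given checkpoint was trained.
print(torch.allclose(parallel_out, sequential_out))  # virtually always False
```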