Unverified commit 225c36fb authored by Chiao, committed by GitHub

gradient checkpointing for GPT-NeoX (#19946)

* gradient checkpointing for GPT-NeoX

* initialize gradient checkpointing flag

* must set flag before init
parent 6176e136
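A minimal usage sketch of the new flag, assuming the PR also wires GPT-NeoX into the standard `PreTrainedModel.gradient_checkpointing_enable()` helper (i.e. `supports_gradient_checkpointing` on `GPTNeoXPreTrainedModel`); the tiny config values below are purely illustrative:

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Small illustrative config so the sketch runs quickly; real GPT-NeoX checkpoints are far larger.
config = GPTNeoXConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=4, num_attention_heads=8)
model = GPTNeoXForCausalLM(config)

# Flip the new flag on the underlying GPTNeoXModel via the PreTrainedModel helper.
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
# Keep use_cache off while checkpointing; the forward pass warns and disables it otherwise.
outputs = model(input_ids, labels=input_ids, use_cache=False)
outputs.loss.backward()
```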
@@ -422,6 +422,8 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -518,14 +520,37 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            outputs = layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                layer_past=layer_past,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for layer_past
+                        return module(*inputs, use_cache, None, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                )
+            else:
+                outputs = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
             hidden_states = outputs[0]
             if use_cache is True:
                 presents = presents + (outputs[1],)
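The `create_custom_forward` closure follows the pattern other models in the library use with `torch.utils.checkpoint.checkpoint`: only tensors are passed through `checkpoint`, while the non-tensor arguments (`use_cache`, `output_attentions`) and the unused `layer_past` (always `None` during checkpointed training) are captured by the closure. A self-contained sketch of that pattern, with a purely illustrative `ToyLayer` standing in for `GPTNeoXLayer`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Stands in for GPTNeoXLayer: takes a tensor plus non-tensor flags, returns a tuple."""

    def __init__(self, dim=32):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, use_cache, layer_past, output_attentions):
        out = torch.relu(self.linear(hidden_states))
        # Mimic the tuple-style return of a transformer layer.
        return (out,)


layer = ToyLayer()
hidden_states = torch.randn(4, 32, requires_grad=True)
use_cache, output_attentions = False, False


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Non-tensor arguments are closed over; None stands in for layer_past.
        return module(*inputs, use_cache, None, output_attentions)

    return custom_forward


# Activations inside the layer are recomputed during backward instead of being stored.
outputs = checkpoint(create_custom_forward(layer), hidden_states)
outputs[0].sum().backward()
```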