"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5c6f57ee75665499c8045a8bf7c73bf2415fba20"
Unverified commit 225c36fb, authored by Chiao and committed by GitHub

gradient checkpointing for GPT-NeoX (#19946)

* gradient checkpointing for GPT-NeoX

* initialize gradient checkpointing flag

* must set flag before init
parent 6176e136
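As context for the diff below, a minimal sketch of how the new capability is enabled from user code through the standard transformers API (the tiny config values and tensor shapes are illustrative assumptions, not part of this commit):

import torch
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Tiny illustrative config so the sketch runs quickly; a real checkpoint such as
# EleutherAI/gpt-neox-20b would instead be loaded with from_pretrained.
config = GPTNeoXConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
)
model = GPTNeoXForCausalLM(config)

# Toggle the flag introduced by this commit via the generic PreTrainedModel API.
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
loss = model(input_ids, labels=input_ids).loss  # use_cache is forced off while checkpointing
loss.backward()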
@@ -422,6 +422,8 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -518,14 +520,37 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            outputs = layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                layer_past=layer_past,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for layer_past
+                        return module(*inputs, use_cache, None, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                )
+            else:
+                outputs = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
             hidden_states = outputs[0]
             if use_cache is True:
                 presents = presents + (outputs[1],)
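For readers unfamiliar with the wrapper pattern in the hunk above, here is a self-contained sketch of the same torch.utils.checkpoint.checkpoint idiom (the toy block and shapes are assumptions, not GPT-NeoX code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a GPTNeoXLayer: any callable taking tensor inputs works.
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

def create_custom_forward(module):
    # Same closure trick as in the diff: non-tensor arguments (use_cache,
    # layer_past, output_attentions) are captured in the closure instead of
    # being passed through checkpoint(), which only forwards tensors.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(2, 16, 64, requires_grad=True)

# The forward pass does not store the block's intermediate activations;
# they are recomputed when .backward() reaches this block, trading extra
# compute for lower memory use.
out = checkpoint(create_custom_forward(block), x)
out.sum().backward()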