"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9435cc6670b7b8656b33e8ff28d3bbe9bafbca9d"
Unverified commit 89f0fda5, authored by yqy2001, committed by GitHub
Browse files

Fix the gradient checkpointing bug of the llama model (#22270)

fix grad ckpt bug of llama
parent cf0af9a3
@@ -372,7 +372,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
                 module.weight.data[module.padding_idx].zero_()
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LlamaDecoderLayer)):
+        if isinstance(module, LlamaModel):
             module.gradient_checkpointing = value
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment