Unverified Commit 061a73d1 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[CodeGen] support device_map="auto" for sharded checkpoints (#17871)

parent d6b6fb99
...@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel): ...@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
config_class = CodeGenConfig config_class = CodeGenConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment