"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "de8b8b6e5cec42d91e9b7cc3cad849f36f424545"
Unverified commit fabe17a7, authored by Sylvain Gugger and committed by GitHub

Skip device placement for past key values in decoder models (#23919)

parent 6affd9cd
@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)

         if output_loading_info:
             if loading_info is None:
......
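For context (not part of the diff): a minimal, self-contained sketch of the loading-path logic above. The `inspect.signature` guard is there because only newer `accelerate` releases accept a `skip_keys` argument on `dispatch_model`; when it is supported, inputs and outputs matching those keys (here the cached `past_key_values`) are left on whatever device they already occupy instead of being moved by the dispatch hooks. The function name below is an illustrative assumption, not part of the library.

```python
import inspect

from accelerate import dispatch_model


def dispatch_with_optional_skip_keys(model, device_map, offload_folder=None, offload_index=None):
    # Mirrors the guarded call added in PreTrainedModel.from_pretrained above.
    kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
    # Older accelerate versions do not know about `skip_keys`; only forward it when available.
    if "skip_keys" in inspect.signature(dispatch_model).parameters:
        # e.g. "past_key_values" for the decoder models touched in this commit
        kwargs["skip_keys"] = model._skip_keys_device_placement
    dispatch_model(model, **kwargs)
    return model
```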
@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
......
@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
......
@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]

     def _init_weights(self, module):
......
@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
......
@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
......
@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
......
@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
......
@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"

     @property
     def dummy_inputs(self):
......
@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

     def _init_weights(self, module):
......
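As a hedged usage sketch (the checkpoint and hardware are illustrative, and `accelerate` must be installed): once a model class sets `_skip_keys_device_placement = "past_key_values"`, loading it with a `device_map` and generating no longer forces the cached key/value tensors through the dispatch hooks on every decoding step.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# device_map="auto" triggers the dispatch_model call shown above; GPT2PreTrainedModel
# now passes skip_keys="past_key_values" when the installed accelerate supports it.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

inputs = tokenizer("Skipping device placement for the KV cache", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```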