chenpangpang / transformers · Commits · fabe17a7

Unverified commit fabe17a7, authored May 31, 2023 by Sylvain Gugger and committed by GitHub on May 31, 2023.
Skip device placement for past key values in decoder models (#23919)
Parent: 6affd9cd
Showing 15 changed files with 19 additions and 1 deletion (+19, -1).
src/transformers/modeling_utils.py (+5, -1)
src/transformers/models/bart/modeling_bart.py (+1, -0)
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (+1, -0)
src/transformers/models/blip_2/modeling_blip_2.py (+1, -0)
src/transformers/models/bloom/modeling_bloom.py (+1, -0)
src/transformers/models/bridgetower/modeling_bridgetower.py (+1, -0)
src/transformers/models/codegen/modeling_codegen.py (+1, -0)
src/transformers/models/gpt2/modeling_gpt2.py (+1, -0)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (+1, -0)
src/transformers/models/gpt_neo/modeling_gpt_neo.py (+1, -0)
src/transformers/models/gpt_neox/modeling_gpt_neox.py (+1, -0)
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py (+1, -0)
src/transformers/models/gptj/modeling_gptj.py (+1, -0)
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py (+1, -0)
src/transformers/models/llama/modeling_llama.py (+1, -0)
src/transformers/modeling_utils.py

@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing

@@ -2887,7 +2888,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)

         if output_loading_info:
             if loading_info is None:
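The inspect.signature check above keeps the dispatch call compatible with accelerate releases that predate the skip_keys argument: the extra keyword is only forwarded when the installed dispatch_model accepts it. Below is a minimal, self-contained sketch of that feature-detection pattern; dispatch_v1, dispatch_v2, and call_dispatch are hypothetical stand-ins for illustration, not accelerate APIs.

import inspect


def dispatch_v1(model, device_map=None):
    # Stand-in for an older dispatch function that has no skip_keys parameter.
    return f"dispatched {model} with {device_map}"


def dispatch_v2(model, device_map=None, skip_keys=None):
    # Stand-in for a newer dispatch function that accepts skip_keys.
    return f"dispatched {model} with {device_map}, skipping {skip_keys}"


def call_dispatch(dispatch_fn, model, device_map, skip_keys):
    kwargs = {"device_map": device_map}
    if "skip_keys" in inspect.signature(dispatch_fn).parameters:
        # Only forward the keyword when the callee's signature supports it.
        kwargs["skip_keys"] = skip_keys
    return dispatch_fn(model, **kwargs)


print(call_dispatch(dispatch_v1, "toy-model", {"": 0}, "past_key_values"))
print(call_dispatch(dispatch_v2, "toy-model", {"": 0}, "past_key_values"))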
src/transformers/models/bart/modeling_bart.py

@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
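Each of the remaining model files applies the same one-line change: the decoder's pretrained-model base class gains _skip_keys_device_placement = "past_key_values". As a hedged illustration of the pattern only (MyDecoderPreTrainedModel and MyDecoderLayer are hypothetical names, not classes from this commit), a model class would declare:

from transformers import PreTrainedModel


class MyDecoderPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MyDecoderLayer"]
    # Tensors passed under this key are left where they are by accelerate's
    # device-placement hooks (see the modeling_utils.py change above).
    _skip_keys_device_placement = "past_key_values"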
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
src/transformers/models/blip_2/modeling_blip_2.py

@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]

     def _init_weights(self, module):
src/transformers/models/bloom/modeling_bloom.py

@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
src/transformers/models/codegen/modeling_codegen.py

@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/gpt2/modeling_gpt2.py

@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py

@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
src/transformers/models/gptj/modeling_gptj.py

@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py

@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"

     @property
     def dummy_inputs(self):
src/transformers/models/llama/modeling_llama.py

@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

     def _init_weights(self, module):
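End to end, the effect of this commit shows up when a decoder model is loaded with a device_map and then used for generation: the cached key/value tensors are no longer re-routed by accelerate's dispatch hooks on every decoding step. A usage sketch, assuming transformers (with this change) and accelerate are installed; "gpt2" is only an example checkpoint and the output text will vary:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# device_map requires accelerate; "auto" spreads the model over the available devices.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# The class attribute added by this commit.
print(model._skip_keys_device_placement)  # past_key_values

inputs = tokenizer("Device placement of the KV cache is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))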