Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3e41cf13
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5019aabfacf7599b9a6b4e7a1adc1fb5c9017727"
Unverified
Commit
3e41cf13
authored
Aug 10, 2023
by
Joao Gante
Committed by
GitHub
Aug 10, 2023
Browse files
Generate: Load generation config when `device_map` is passed (#25413)
parent
d0839f1a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
10 deletions
+28
-10
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+14
-10
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+14
-0
No files found.
src/transformers/modeling_utils.py
View file @
3e41cf13
...
@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"'sequential'."
"'sequential'."
)
)
kwargs
=
{
"no_split_module_classes"
:
no_split_modules
}
device_map_
kwargs
=
{
"no_split_module_classes"
:
no_split_modules
}
if
"special_dtypes"
in
inspect
.
signature
(
infer_auto_device_map
).
parameters
:
if
"special_dtypes"
in
inspect
.
signature
(
infer_auto_device_map
).
parameters
:
kwargs
[
"special_dtypes"
]
=
special_dtypes
device_map_
kwargs
[
"special_dtypes"
]
=
special_dtypes
elif
len
(
special_dtypes
)
>
0
:
elif
len
(
special_dtypes
)
>
0
:
logger
.
warning
(
logger
.
warning
(
"This model has some weights that should be kept in higher precision, you need to upgrade "
"This model has some weights that should be kept in higher precision, you need to upgrade "
...
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype
=
target_dtype
,
dtype
=
target_dtype
,
low_zero
=
(
device_map
==
"balanced_low_0"
),
low_zero
=
(
device_map
==
"balanced_low_0"
),
max_memory
=
max_memory
,
max_memory
=
max_memory
,
**
kwargs
,
**
device_map_
kwargs
,
)
)
kwargs
[
"max_memory"
]
=
max_memory
device_map_
kwargs
[
"max_memory"
]
=
max_memory
# Make sure tied weights are tied before creating the device map.
# Make sure tied weights are tied before creating the device map.
model
.
tie_weights
()
model
.
tie_weights
()
device_map
=
infer_auto_device_map
(
model
,
dtype
=
target_dtype
,
**
kwargs
)
device_map
=
infer_auto_device_map
(
model
,
dtype
=
target_dtype
,
**
device_map_
kwargs
)
if
load_in_8bit
or
load_in_4bit
:
if
load_in_8bit
or
load_in_4bit
:
# The LM head / tied weights or any last module can stay on disk / CPU
# The LM head / tied weights or any last module can stay on disk / CPU
...
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model
.
eval
()
model
.
eval
()
# If it is a model with generation capabilities, attempt to load the generation config
# If it is a model with generation capabilities, attempt to load the generation config
if
model
.
can_generate
():
if
model
.
can_generate
()
and
pretrained_model_name_or_path
is
not
None
:
try
:
try
:
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
...
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_from_pipeline
=
from_pipeline
,
_from_pipeline
=
from_pipeline
,
**
kwargs
,
**
kwargs
,
)
)
except
(
OSError
,
TypeError
)
:
except
OSError
:
logger
.
info
(
logger
.
info
(
"Generation config file not found, using a generation config created from the model config."
"Generation config file not found, using a generation config created from the model config."
)
)
...
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Dispatch model with hooks on all devices if necessary
# Dispatch model with hooks on all devices if necessary
if
device_map
is
not
None
:
if
device_map
is
not
None
:
kwargs
=
{
"device_map"
:
device_map
,
"offload_dir"
:
offload_folder
,
"offload_index"
:
offload_index
}
device_map_kwargs
=
{
"device_map"
:
device_map
,
"offload_dir"
:
offload_folder
,
"offload_index"
:
offload_index
,
}
if
"skip_keys"
in
inspect
.
signature
(
dispatch_model
).
parameters
:
if
"skip_keys"
in
inspect
.
signature
(
dispatch_model
).
parameters
:
kwargs
[
"skip_keys"
]
=
model
.
_skip_keys_device_placement
device_map_
kwargs
[
"skip_keys"
]
=
model
.
_skip_keys_device_placement
dispatch_model
(
model
,
**
kwargs
)
dispatch_model
(
model
,
**
device_map_
kwargs
)
if
output_loading_info
:
if
output_loading_info
:
if
loading_info
is
None
:
if
loading_info
is
None
:
...
...
tests/test_modeling_utils.py
View file @
3e41cf13
...
@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
self
.
assertEqual
(
model
.
__class__
.
__name__
,
model_ref
.
__class__
.
__name__
)
self
.
assertEqual
(
model
.
__class__
.
__name__
,
model_ref
.
__class__
.
__name__
)
def test_generation_config_is_loaded_with_model(self):
    """Verify that `from_pretrained` loads the checkpoint's generation config.

    The `joaogante/tiny-random-gpt2-with-generation-config` checkpoint ships a
    `generation_config.json` whose `transformers_version` field is a dummy value, `foo`.
    If that file is not loaded, the field will not read `foo` and the assertion fails.
    """
    checkpoint = "joaogante/tiny-random-gpt2-with-generation-config"

    # Scenario 1: load without further parameters.
    # Scenario 2: load with `device_map`.
    for extra_kwargs in ({}, {"device_map": "auto"}):
        model = AutoModelForCausalLM.from_pretrained(checkpoint, **extra_kwargs)
        self.assertEqual(model.generation_config.transformers_version, "foo")
@
require_torch
@
require_torch
@
is_staging_test
@
is_staging_test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment