Unverified commit 41496b95 (parent b18e3140)
Authored Oct 24, 2023 by Marc Sun; committed by GitHub on Oct 24, 2023

Add fuyu device map (#26949)

* add _no_split_modules
* style
* fix _no_split_modules
* add doc
Showing 2 changed files with 31 additions and 6 deletions:

* src/transformers/modeling_utils.py (+30, -6)
* src/transformers/models/fuyu/modeling_fuyu.py (+1, -0)
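In practice, this change is what allows a Fuyu checkpoint to be loaded with a `device_map` and dispatched across several devices. A minimal usage sketch (assuming a transformers version that includes this commit and that `accelerate` is installed; the `adept/fuyu-8b` checkpoint is only an illustrative choice, not part of this commit):

import torch
from transformers import FuyuForCausalLM, FuyuProcessor

# Sketch only: `device_map="auto"` asks accelerate to shard the model across the
# available GPUs/CPU, splitting it only at module boundaries allowed by
# `_no_split_modules`.
processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained(
    "adept/fuyu-8b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)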
src/transformers/modeling_utils.py
...
@@ -1507,6 +1507,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings
 
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        if self._no_split_modules is None:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                "class needs to implement the `_no_split_modules` attribute."
+            )
+        _no_split_modules = set(self._no_split_modules)
+        for module in self.modules():
+            if isinstance(module, PreTrainedModel):
+                if module._no_split_modules is None:
+                    raise ValueError(
+                        f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                        "class needs to implement the `_no_split_modules` attribute."
+                    )
+                else:
+                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+        return list(_no_split_modules)
+
     def resize_token_embeddings(
         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
     ) -> nn.Embedding:
...
@@ -3226,12 +3255,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             elif load_in_8bit:
                 target_dtype = torch.int8
 
-            if model._no_split_modules is None:
-                raise ValueError(
-                    f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                    "class needs to implement the `_no_split_modules` attribute."
-                )
-            no_split_modules = model._no_split_modules
+            no_split_modules = model._get_no_split_modules(device_map)
             if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
                 raise ValueError(
                     "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
...
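The new helper walks every sub-module and unions the `_no_split_modules` lists declared by nested `PreTrainedModel` instances, which is what composite models such as Fuyu (a vision tower plus a language model) need under `device_map`. A toy sketch of that aggregation (the `Toy*` classes are hypothetical and not part of the commit; only `_get_no_split_modules` comes from this change):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyBackbone(PreTrainedModel):
    # Hypothetical sub-model: each "ToyDecoderLayer" must stay on a single device.
    config_class = ToyConfig
    _no_split_modules = ["ToyDecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

class ToyComposite(PreTrainedModel):
    # Hypothetical top-level model wrapping the backbone; declares nothing extra itself.
    config_class = ToyConfig
    _no_split_modules = []

    def __init__(self, config):
        super().__init__(config)
        self.backbone = ToyBackbone(config)

model = ToyComposite(ToyConfig())
# Unions the sets found on ToyComposite and ToyBackbone -> ["ToyDecoderLayer"]
print(model._get_no_split_modules("auto"))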
src/transformers/models/fuyu/modeling_fuyu.py
...
@@ -262,6 +262,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
             if image_patches is not None and past_key_values is None:
                 patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
+                patch_embeddings = patch_embeddings.to(inputs_embeds.device)
                 inputs_embeds = self.gather_continuous_embeddings(
                     word_embeddings=inputs_embeds,
                     continuous_embeddings=patch_embeddings,
...
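The single line added to modeling_fuyu.py matters when a device map places `vision_embed_tokens` and the language model's input embeddings on different devices: the patch embeddings have to be moved onto the device of `inputs_embeds` before the two are merged. A toy illustration of the general failure mode (hypothetical two-GPU setup, unrelated to the actual Fuyu modules):

import torch

if torch.cuda.device_count() >= 2:
    word_embeddings = torch.randn(1, 8, 16, device="cuda:0")   # stand-in for inputs_embeds
    patch_embeddings = torch.randn(1, 4, 16, device="cuda:1")  # stand-in for the vision output

    try:
        # Combining tensors that live on different devices raises a RuntimeError.
        torch.cat([word_embeddings, patch_embeddings], dim=1)
    except RuntimeError as err:
        print(f"cross-device merge failed: {err}")

    # Moving the patch embeddings first, as the added line does, makes the merge work.
    merged = torch.cat([word_embeddings, patch_embeddings.to(word_embeddings.device)], dim=1)
    print(merged.shape)  # torch.Size([1, 12, 16])

In the real model the merge happens inside `gather_continuous_embeddings`, which is why the commit moves `patch_embeddings` to `inputs_embeds.device` right after the vision projection.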