Commit 95020f20 (unverified)
Authored Nov 01, 2023 by Lysandre Debut, committed via GitHub on Nov 01, 2023
Parent: c9e72f55

Fix CPU offload + disk offload tests (#27204)

Fix disk offload tests + weight sharing issues
Showing 14 changed files with 147 additions and 5 deletions (+147, -5):

src/transformers/modeling_utils.py (+3, -1)
src/transformers/models/bart/modeling_bart.py (+5, -0)
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (+5, -0)
src/transformers/models/longt5/modeling_longt5.py (+14, -0)
src/transformers/models/m2m_100/modeling_m2m_100.py (+5, -0)
src/transformers/models/nllb_moe/modeling_nllb_moe.py (+5, -0)
src/transformers/models/plbart/modeling_plbart.py (+5, -0)
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py (+5, -0)
src/transformers/models/switch_transformers/modeling_switch_transformers.py (+14, -0)
src/transformers/models/t5/modeling_t5.py (+19, -0)
src/transformers/models/umt5/modeling_umt5.py (+23, -0)
tests/models/vitdet/test_modeling_vitdet.py (+5, -1)
tests/models/whisper/test_modeling_whisper.py (+5, -1)
tests/test_modeling_common.py (+34, -2)
src/transformers/modeling_utils.py

@@ -4576,7 +4576,9 @@ def expand_device_map(device_map, param_names):
     """
     new_device_map = {}
     for module, device in device_map.items():
-        new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
     return new_device_map
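The new `module == ""` clause handles a root-level device entry: accelerate-style device maps use the empty string to mean the whole model, and without the clause such an entry expanded to no parameters at all (no parameter name equals `""` or starts with `"."`). A minimal standalone sketch of the patched helper, copied out of context for illustration:

def expand_device_map(device_map, param_names):
    # Standalone copy of the patched helper, for illustration only.
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map

params = ["encoder.embed_tokens.weight", "lm_head.weight"]
# A root-level entry ({"": "cpu"}) now covers every parameter:
print(expand_device_map({"": "cpu"}, params))
# -> {'encoder.embed_tokens.weight': 'cpu', 'lm_head.weight': 'cpu'}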
src/transformers/models/bart/modeling_bart.py

@@ -1125,6 +1125,11 @@ class BartModel(BartPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_input_embeddings(self):
         return self.shared
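This `_tie_weights` override, repeated with the appropriate attribute names across the model files below, addresses the weight-sharing half of the fix: when `tie_word_embeddings` is set, `encoder.embed_tokens`, `decoder.embed_tokens`, and `shared` are meant to be one tensor, and a checkpoint stores that tensor once, so the loader has to re-tie the modules after materializing weights on their target devices. A minimal sketch of the idea in plain PyTorch (a toy module, not the transformers classes; `_tie_or_clone_weights` assigns the same parameter in the default path and clones only for TorchScript):

import torch.nn as nn

class ToySeq2Seq(nn.Module):
    # Toy stand-in for the shared-embedding layout of BartModel and friends.
    def __init__(self, vocab_size=10, dim=4):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, dim)
        self.encoder_embed = nn.Embedding(vocab_size, dim)
        self.decoder_embed = nn.Embedding(vocab_size, dim)
        self._tie_weights()

    def _tie_weights(self):
        # Point both sub-embeddings at the shared parameter so a checkpoint
        # that stores `shared` once still populates all three modules.
        self.encoder_embed.weight = self.shared.weight
        self.decoder_embed.weight = self.shared.weight

model = ToySeq2Seq()
assert model.encoder_embed.weight is model.shared.weight  # one shared tensor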
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

@@ -2312,6 +2312,11 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/longt5/modeling_longt5.py

@@ -1783,6 +1783,11 @@ class LongT5Model(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1937,6 +1942,11 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -2170,6 +2180,10 @@ class LongT5EncoderModel(LongT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/m2m_100/modeling_m2m_100.py

@@ -1103,6 +1103,11 @@ class M2M100Model(M2M100PreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/nllb_moe/modeling_nllb_moe.py

@@ -1471,6 +1471,11 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/plbart/modeling_plbart.py

@@ -1084,6 +1084,11 @@ class PLBartModel(PLBartPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py

@@ -4125,6 +4125,11 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
         self.text_decoder.embed_tokens = value
         self.shared = value

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
         self,
src/transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -1329,6 +1329,11 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1505,6 +1510,11 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1807,6 +1817,10 @@ class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/t5/modeling_t5.py

@@ -1414,6 +1414,11 @@ class T5Model(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1620,6 +1625,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1920,6 +1930,10 @@ class T5EncoderModel(T5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -2152,6 +2166,11 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
src/transformers/models/umt5/modeling_umt5.py

@@ -973,6 +973,12 @@ class UMT5Model(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
     def get_encoder(self):
         return self.encoder

@@ -1142,6 +1148,12 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1380,6 +1392,11 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
     def get_encoder(self):
         return self.encoder

@@ -1615,6 +1632,12 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
     def get_encoder(self):
         return self.encoder
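Note that the UMT5 additions carry `# Copied from transformers.models.t5.modeling_t5...` annotations: in this repository such comments mark code that the copy-consistency check (utils/check_copies.py, if I recall the tooling correctly) keeps in sync with the referenced T5 methods, so the overrides are effectively maintained in one place.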
tests/models/vitdet/test_modeling_vitdet.py

@@ -182,7 +182,11 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # TODO: Fix me (once this model gets more usage)
     @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        super().test_disk_offload()
+
+    @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+    def test_disk_offload_safetensors(self):
         super().test_disk_offload()

     # TODO: Fix me (once this model gets more usage)
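The duplication here is deliberate: the common disk-offload test is split below into `test_disk_offload_bin` and `test_disk_offload_safetensors`, so a model-specific skip has to override both names or one variant would still run (and fail) on the tiny model.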
tests/models/whisper/test_modeling_whisper.py

@@ -1788,7 +1788,11 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         pass

     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
+    def test_disk_offload_safetensors(self):
         pass

     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
tests/test_modeling_common.py

@@ -2578,7 +2578,7 @@ class ModelTesterMixin:
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_gpu
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:

@@ -2593,7 +2593,7 @@ class ModelTesterMixin:
             model_size = compute_module_sizes(model)[""]
             with tempfile.TemporaryDirectory() as tmp_dir:
-                model.cpu().save_pretrained(tmp_dir)
+                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

                 with self.assertRaises(ValueError):
                     max_size = int(self.model_split_percents[0] * model_size)

@@ -2613,6 +2613,38 @@ class ModelTesterMixin:
             self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_gpu
+    def test_disk_offload_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_class)
+
+            model_size = compute_module_sizes(model)[""]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+
+                # This doesn't error out as it's in safetensors and doesn't need an offload folder
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict_class)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_gpu
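For context, a hedged sketch of the user-facing behavior these two tests pin down (the model id is a placeholder, not from this commit): with a `.bin` (pickle) checkpoint, disk offload requires an explicit `offload_folder`, and omitting it raises the ValueError the first test asserts; a safetensors checkpoint can back the offloaded weights directly from the saved file, which is what the new test exercises.

from transformers import AutoModelForSeq2SeqLM

MODEL_ID = "some/seq2seq-model"  # placeholder; any checkpoint larger than the budget works

# .bin checkpoint: weights that don't fit in max_memory must be copied
# into an offload folder on disk.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    max_memory={0: "1GiB", "cpu": "1GiB"},
    offload_folder="offload",  # required for .bin disk offload
)

# With a safetensors checkpoint, the same call works without offload_folder,
# since the saved file itself can serve the offloaded weights.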