Commit a26f4d62 (unverified)
Authored May 21, 2021 by Stas Bekman, committed by GitHub on May 21, 2021
Parent: 82335185

[Deepspeed] support `zero.Init` in `from_config` (#11805)

* support zero.Init in from_config
* no need for eval test
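Under DeepSpeed ZeRO-3, wrapping model construction in deepspeed.zero.Init() partitions the parameters across ranks as they are created, instead of first materializing a full copy of the model on every process (the behaviour from_pretrained already had; this commit extends it to from_config). A rough, purely illustrative back-of-the-envelope sketch of what that saves per rank (the numbers are hypothetical, not from the commit):

# Illustrative arithmetic only: ZeRO-3 parameter partitioning splits the
# weights roughly evenly across ranks, so each process holds ~1/world_size
# of them instead of a full copy.
n_params = 1_500_000_000      # hypothetical ~1.5B-parameter model
bytes_per_param = 2           # fp16 weights
world_size = 8                # hypothetical number of GPUs

full_copy_gib = n_params * bytes_per_param / 2**30
sharded_gib = full_copy_gib / world_size

print(f"weights per rank without zero.Init: {full_copy_gib:.1f} GiB")
print(f"weights per rank with ZeRO-3 partitioning: ~{sharded_gib:.1f} GiB")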
Changes: 2 files changed, 48 additions and 2 deletions (+48 / -2)

  src/transformers/models/auto/auto_factory.py   +15 / -1
  tests/deepspeed/test_deepspeed.py              +33 / -1
src/transformers/models/auto/auto_factory.py

@@ -18,9 +18,14 @@ import types

 from ...configuration_utils import PretrainedConfig
 from ...file_utils import copy_func
+from ...integrations import deepspeed_config, is_deepspeed_zero3_enabled
 from ...utils import logging
 from .configuration_auto import AutoConfig, replace_list_option_in_docstrings


+logger = logging.get_logger(__name__)


 CLASS_DOCSTRING = """
     This is a generic model class that will be instantiated as one of the model classes of the library when created
     with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the

@@ -362,7 +367,16 @@ class _BaseAutoModelClass:
     def from_config(cls, config, **kwargs):
         if type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
-            return model_class(config, **kwargs)
+
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+                # this immediately partitions the model across all gpus, to avoid the overhead in time
+                # and memory copying it on CPU or each GPU first
+                with deepspeed.zero.Init(config=deepspeed_config()):
+                    return model_class(config, **kwargs)
+            else:
+                return model_class(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
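With this change, the AutoModel*.from_config(config) path gets the same ZeRO-3 treatment that from_pretrained already had: when a DeepSpeed ZeRO-3 config is active, the model is built inside deepspeed.zero.Init() and its parameters are partitioned across ranks as they are created. A minimal sketch of the calling pattern, not part of the commit, assuming a launcher such as deepspeed has set up the distributed environment and that ds_config_zero3.json enables stage 3:

from transformers import AutoConfig, AutoModelForCausalLM, TrainingArguments

# Assumption: constructing TrainingArguments with a ZeRO-3 deepspeed config is
# what lets is_deepspeed_zero3_enabled() return True before any model exists,
# mirroring how the example scripts (e.g. run_clm.py) parse --deepspeed first.
training_args = TrainingArguments(output_dir="out", deepspeed="ds_config_zero3.json")

config = AutoConfig.from_pretrained("sshleifer/tiny-gpt2")

# Training from scratch: no pretrained weights, so run_clm.py calls from_config().
# With this commit the construction happens under deepspeed.zero.Init(), so each
# rank allocates only its shard of the parameters instead of the full model.
model = AutoModelForCausalLM.from_config(config)

The new test below drives exactly this path end to end through examples/pytorch/language-modeling/run_clm.py under the deepspeed launcher.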
tests/deepspeed/test_deepspeed.py

@@ -25,6 +25,7 @@ from transformers.file_utils import WEIGHTS_NAME

 from transformers.integrations import is_deepspeed_available
 from transformers.testing_utils import (
     CaptureLogger,
+    CaptureStderr,
     ExtendSysPath,
     TestCasePlus,
     execute_subprocess_async,

@@ -741,7 +742,38 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
         execute_subprocess_async(cmd, env=self.get_env())

         return output_dir

+    def test_clm_from_config_zero3(self):
+        # this test exercises AutoModel.from_config(config) - to ensure zero.Init is called
+        data_dir = self.tests_dir / "fixtures"
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"""
+            --model_type gpt2
+            --tokenizer_name sshleifer/tiny-gpt2
+            --train_file {data_dir}/sample_text.txt
+            --validation_file {data_dir}/sample_text.txt
+            --output_dir {output_dir}
+            --overwrite_output_dir
+            --do_train
+            --max_train_samples 4
+            --per_device_train_batch_size 2
+            --num_train_epochs 1
+            --warmup_steps 8
+            --block_size 8
+            --fp16
+            --report_to none
+            """.split()
+
+        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
+        script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
+        launcher = self.get_launcher(distributed=True)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        with CaptureStderr() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        assert "Detected DeepSpeed ZeRO-3" in cs.err
+
     def get_launcher(self, distributed=False):
         # 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
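The assertion keys off the logger.info message added in from_config, captured from the launched subprocess's stderr, so the test fails if the ZeRO-3 branch is never taken. The fixture it points at, tests/deepspeed/ds_config_zero3.json, is a stage-3 DeepSpeed config; a stripped-down, illustrative stand-in for the kind of config that makes is_deepspeed_zero3_enabled() true is sketched below (the real fixture carries additional optimizer/scheduler sections, mostly set to "auto" so the Trainer can fill them in):

import json

# Hypothetical minimal stand-in for ds_config_zero3.json: the only part that
# matters for the zero.Init branch is ZeRO stage 3 being enabled.
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": "auto",
    "fp16": {"enabled": "auto"},
}

with open("ds_config_zero3.json", "w") as f:
    json.dump(ds_config, f, indent=4)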