Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a26f4d62
Unverified
Commit
a26f4d62
authored
May 21, 2021
by
Stas Bekman
Committed by
GitHub
May 21, 2021
Browse files
[Deepspeed] support `zero.Init` in `from_config` (#11805)
* support zero.Init in from_config * no need for eval test
parent
82335185
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
2 deletions
+48
-2
src/transformers/models/auto/auto_factory.py
src/transformers/models/auto/auto_factory.py
+15
-1
tests/deepspeed/test_deepspeed.py
tests/deepspeed/test_deepspeed.py
+33
-1
No files found.
src/transformers/models/auto/auto_factory.py
View file @
a26f4d62
...
@@ -18,9 +18,14 @@ import types
...
@@ -18,9 +18,14 @@ import types
from
...configuration_utils
import
PretrainedConfig
from
...configuration_utils
import
PretrainedConfig
from
...file_utils
import
copy_func
from
...file_utils
import
copy_func
from
...integrations
import
deepspeed_config
,
is_deepspeed_zero3_enabled
from
...utils
import
logging
from
.configuration_auto
import
AutoConfig
,
replace_list_option_in_docstrings
from
.configuration_auto
import
AutoConfig
,
replace_list_option_in_docstrings
logger
=
logging
.
get_logger
(
__name__
)
CLASS_DOCSTRING
=
"""
CLASS_DOCSTRING
=
"""
This is a generic model class that will be instantiated as one of the model classes of the library when created
This is a generic model class that will be instantiated as one of the model classes of the library when created
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
...
@@ -362,7 +367,16 @@ class _BaseAutoModelClass:
...
@@ -362,7 +367,16 @@ class _BaseAutoModelClass:
def
from_config
(
cls
,
config
,
**
kwargs
):
def
from_config
(
cls
,
config
,
**
kwargs
):
if
type
(
config
)
in
cls
.
_model_mapping
.
keys
():
if
type
(
config
)
in
cls
.
_model_mapping
.
keys
():
model_class
=
_get_model_class
(
config
,
cls
.
_model_mapping
)
model_class
=
_get_model_class
(
config
,
cls
.
_model_mapping
)
return
model_class
(
config
,
**
kwargs
)
if
is_deepspeed_zero3_enabled
():
import
deepspeed
logger
.
info
(
"Detected DeepSpeed ZeRO-3: activating zero.init() for this model"
)
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with
deepspeed
.
zero
.
Init
(
config
=
deepspeed_config
()):
return
model_class
(
config
,
**
kwargs
)
else
:
return
model_class
(
config
,
**
kwargs
)
raise
ValueError
(
raise
ValueError
(
f
"Unrecognized configuration class
{
config
.
__class__
}
for this kind of AutoModel:
{
cls
.
__name__
}
.
\n
"
f
"Unrecognized configuration class
{
config
.
__class__
}
for this kind of AutoModel:
{
cls
.
__name__
}
.
\n
"
f
"Model type should be one of
{
', '
.
join
(
c
.
__name__
for
c
in
cls
.
_model_mapping
.
keys
())
}
."
f
"Model type should be one of
{
', '
.
join
(
c
.
__name__
for
c
in
cls
.
_model_mapping
.
keys
())
}
."
...
...
tests/deepspeed/test_deepspeed.py
View file @
a26f4d62
...
@@ -25,6 +25,7 @@ from transformers.file_utils import WEIGHTS_NAME
...
@@ -25,6 +25,7 @@ from transformers.file_utils import WEIGHTS_NAME
from
transformers.integrations
import
is_deepspeed_available
from
transformers.integrations
import
is_deepspeed_available
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
CaptureLogger
,
CaptureLogger
,
CaptureStderr
,
ExtendSysPath
,
ExtendSysPath
,
TestCasePlus
,
TestCasePlus
,
execute_subprocess_async
,
execute_subprocess_async
,
...
@@ -741,7 +742,38 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
...
@@ -741,7 +742,38 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
execute_subprocess_async
(
cmd
,
env
=
self
.
get_env
())
execute_subprocess_async
(
cmd
,
env
=
self
.
get_env
())
return
output_dir
def
test_clm_from_config_zero3
(
self
):
# this test exercises AutoModel.from_config(config) - to ensure zero.Init is called
data_dir
=
self
.
tests_dir
/
"fixtures"
output_dir
=
self
.
get_auto_remove_tmp_dir
()
args
=
f
"""
--model_type gpt2
--tokenizer_name sshleifer/tiny-gpt2
--train_file
{
data_dir
}
/sample_text.txt
--validation_file
{
data_dir
}
/sample_text.txt
--output_dir
{
output_dir
}
--overwrite_output_dir
--do_train
--max_train_samples 4
--per_device_train_batch_size 2
--num_train_epochs 1
--warmup_steps 8
--block_size 8
--fp16
--report_to none
"""
.
split
()
ds_args
=
f
"--deepspeed
{
self
.
test_file_dir_str
}
/ds_config_zero3.json"
.
split
()
script
=
[
f
"
{
self
.
examples_dir_str
}
/pytorch/language-modeling/run_clm.py"
]
launcher
=
self
.
get_launcher
(
distributed
=
True
)
cmd
=
launcher
+
script
+
args
+
ds_args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
with
CaptureStderr
()
as
cs
:
execute_subprocess_async
(
cmd
,
env
=
self
.
get_env
())
assert
"Detected DeepSpeed ZeRO-3"
in
cs
.
err
def
get_launcher
(
self
,
distributed
=
False
):
def
get_launcher
(
self
,
distributed
=
False
):
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment