chenpangpang / transformers · Commits

Unverified commit f06d0fad, authored Dec 17, 2020 by Stas Bekman, committed via GitHub on Dec 17, 2020
[trainer] apex fixes and tests (#9180)
Parent: 467e9158

Showing 2 changed files with 22 additions and 8 deletions:
examples/seq2seq/test_finetune_trainer.py   +17 -1
src/transformers/trainer.py                  +5 -7
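The commit wires NVIDIA apex into the Trainer's fp16 backend handling and adds a test that exercises it through the command-line flags "--fp16 --fp16_backend=apex". As a rough illustration only (not part of the diff), the same configuration done programmatically might look like the sketch below; the output_dir value and the commented-out model/dataset are placeholders.

# Sketch only: enable the apex fp16 backend that the new test exercises via
# "--fp16 --fp16_backend=apex". output_dir and the commented-out model/dataset
# are placeholders, not values from the commit.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    fp16=True,              # turn on mixed-precision training
    fp16_backend="apex",    # force apex instead of native torch.cuda.amp
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()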
examples/seq2seq/test_finetune_trainer.py (view file @ f06d0fad)
@@ -18,7 +18,7 @@ import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
-from transformers.file_utils import is_datasets_available
+from transformers.file_utils import is_apex_available, is_datasets_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
     return test_case
 
 
+# a candidate for testing_utils
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
     def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
         output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    # it's crucial to test --sharded_ddp w/ and w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_finetune_trainer_ddp_sharded_ddp(self):
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
         self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
 
+    @require_apex
+    def test_finetune_trainer_apex(self):
+        self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
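The new require_apex helper follows the same pattern as require_fairscale above: it wraps unittest.skip when apex is missing, otherwise it returns the test unchanged. A minimal, self-contained illustration of how such a decorator composes with unittest follows; the ToyApexTest class is hypothetical, not part of the repository.

# Illustration only: the decorator mirrors require_apex from the diff above;
# ToyApexTest is a made-up example class, not part of the repository.
import unittest

from transformers.file_utils import is_apex_available


def require_apex(test_case):
    """Skip the decorated test when apex is not installed."""
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    return test_case


class ToyApexTest(unittest.TestCase):
    @require_apex
    def test_amp_is_importable(self):
        from apex import amp  # only reachable when apex is installed

        self.assertTrue(callable(amp.initialize))


if __name__ == "__main__":
    unittest.main()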
src/transformers/trainer.py (view file @ f06d0fad)
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -104,13 +104,10 @@ if is_in_notebook():
     DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
 
-# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
-if version.parse(torch.__version__) < version.parse("1.6"):
-    from .file_utils import is_apex_available
+if is_apex_available():
+    from apex import amp
 
-    if is_apex_available():
-        from apex import amp
-else:
+if version.parse(torch.__version__) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
@@ -309,6 +306,7 @@ class Trainer:
                 backend = "amp" if _is_native_amp_available else "apex"
             else:
                 backend = args.fp16_backend
+            logger.info(f"Using {backend} fp16 backend")
 
             if backend == "amp":
                 self.use_amp = True
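The trainer.py hunks do two things: apex's amp is now imported whenever apex is installed, independently of the torch version, and the chosen fp16 backend is logged. Below is a standalone sketch of the resulting selection logic, under the assumption that "auto" still prefers native amp on torch >= 1.6; the function name and the explicit ImportError guard are illustrative, not the Trainer's verbatim code.

# Standalone sketch of the fp16 backend selection the diff touches; the function
# name and the ImportError guard are illustrative assumptions.
import torch
from packaging import version

from transformers.file_utils import is_apex_available

_is_native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")


def resolve_fp16_backend(fp16_backend: str = "auto") -> str:
    """Mimic how --fp16_backend resolves to "amp" or "apex"."""
    if fp16_backend == "auto":
        backend = "amp" if _is_native_amp_available else "apex"
    else:
        backend = fp16_backend
    if backend == "apex" and not is_apex_available():
        raise ImportError("apex fp16 backend requested but apex is not installed")
    print(f"Using {backend} fp16 backend")
    return backend


resolve_fp16_backend()            # "amp" on any torch >= 1.6
if is_apex_available():
    resolve_fp16_backend("apex")  # "apex" only when apex is installed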