Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f06d0fad
Unverified
Commit
f06d0fad
authored
Dec 17, 2020
by
Stas Bekman
Committed by
GitHub
Dec 17, 2020
Browse files
[trainer] apex fixes and tests (#9180)
parent
467e9158
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
8 deletions
+22
-8
examples/seq2seq/test_finetune_trainer.py
examples/seq2seq/test_finetune_trainer.py
+17
-1
src/transformers/trainer.py
src/transformers/trainer.py
+5
-7
No files found.
examples/seq2seq/test_finetune_trainer.py
View file @
f06d0fad
...
@@ -18,7 +18,7 @@ import unittest
...
@@ -18,7 +18,7 @@ import unittest
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
transformers
import
BertTokenizer
,
EncoderDecoderModel
from
transformers
import
BertTokenizer
,
EncoderDecoderModel
from
transformers.file_utils
import
is_datasets_available
from
transformers.file_utils
import
is_apex_available
,
is_datasets_available
from
transformers.integrations
import
is_fairscale_available
from
transformers.integrations
import
is_fairscale_available
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
TestCasePlus
,
TestCasePlus
,
...
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
...
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
return
test_case
return
test_case
# a candidate for testing_utils
def
require_apex
(
test_case
):
"""
Decorator marking a test that requires apex
"""
if
not
is_apex_available
():
return
unittest
.
skip
(
"test requires apex"
)(
test_case
)
else
:
return
test_case
class
TestFinetuneTrainer
(
TestCasePlus
):
class
TestFinetuneTrainer
(
TestCasePlus
):
def
finetune_trainer_quick
(
self
,
distributed
=
None
,
extra_args_str
=
None
):
def
finetune_trainer_quick
(
self
,
distributed
=
None
,
extra_args_str
=
None
):
output_dir
=
self
.
run_trainer
(
1
,
"12"
,
MBART_TINY
,
1
,
distributed
,
extra_args_str
)
output_dir
=
self
.
run_trainer
(
1
,
"12"
,
MBART_TINY
,
1
,
distributed
,
extra_args_str
)
...
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
...
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
def
test_finetune_trainer_ddp
(
self
):
def
test_finetune_trainer_ddp
(
self
):
self
.
finetune_trainer_quick
(
distributed
=
True
)
self
.
finetune_trainer_quick
(
distributed
=
True
)
# it's crucial to test --sharded_ddp w/ and w/o --fp16
@
require_torch_multi_gpu
@
require_torch_multi_gpu
@
require_fairscale
@
require_fairscale
def
test_finetune_trainer_ddp_sharded_ddp
(
self
):
def
test_finetune_trainer_ddp_sharded_ddp
(
self
):
...
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
...
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
def
test_finetune_trainer_ddp_sharded_ddp_fp16
(
self
):
def
test_finetune_trainer_ddp_sharded_ddp_fp16
(
self
):
self
.
finetune_trainer_quick
(
distributed
=
True
,
extra_args_str
=
"--sharded_ddp --fp16"
)
self
.
finetune_trainer_quick
(
distributed
=
True
,
extra_args_str
=
"--sharded_ddp --fp16"
)
@
require_apex
def
test_finetune_trainer_apex
(
self
):
self
.
finetune_trainer_quick
(
extra_args_str
=
"--fp16 --fp16_backend=apex"
)
@
slow
@
slow
def
test_finetune_trainer_slow
(
self
):
def
test_finetune_trainer_slow
(
self
):
# There is a missing call to __init__process_group somewhere
# There is a missing call to __init__process_group somewhere
...
...
src/transformers/trainer.py
View file @
f06d0fad
...
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
...
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
from
torch.utils.data.sampler
import
RandomSampler
,
SequentialSampler
from
torch.utils.data.sampler
import
RandomSampler
,
SequentialSampler
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
from
.file_utils
import
WEIGHTS_NAME
,
is_datasets_available
,
is_in_notebook
,
is_torch_tpu_available
from
.file_utils
import
WEIGHTS_NAME
,
is_apex_available
,
is_datasets_available
,
is_in_notebook
,
is_torch_tpu_available
from
.modeling_utils
import
PreTrainedModel
from
.modeling_utils
import
PreTrainedModel
from
.models.auto.modeling_auto
import
MODEL_FOR_QUESTION_ANSWERING_MAPPING
from
.models.auto.modeling_auto
import
MODEL_FOR_QUESTION_ANSWERING_MAPPING
from
.optimization
import
AdamW
,
get_linear_schedule_with_warmup
from
.optimization
import
AdamW
,
get_linear_schedule_with_warmup
...
@@ -104,13 +104,10 @@ if is_in_notebook():
...
@@ -104,13 +104,10 @@ if is_in_notebook():
DEFAULT_PROGRESS_CALLBACK
=
NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK
=
NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if
is_apex_available
():
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.6"
):
from
apex
import
amp
from
.file_utils
import
is_apex_available
if
is_apex_available
():
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
"1.6"
):
from
apex
import
amp
else
:
_is_native_amp_available
=
True
_is_native_amp_available
=
True
from
torch.cuda.amp
import
autocast
from
torch.cuda.amp
import
autocast
...
@@ -309,6 +306,7 @@ class Trainer:
...
@@ -309,6 +306,7 @@ class Trainer:
backend
=
"amp"
if
_is_native_amp_available
else
"apex"
backend
=
"amp"
if
_is_native_amp_available
else
"apex"
else
:
else
:
backend
=
args
.
fp16_backend
backend
=
args
.
fp16_backend
logger
.
info
(
f
"Using
{
backend
}
fp16 backend"
)
if
backend
==
"amp"
:
if
backend
==
"amp"
:
self
.
use_amp
=
True
self
.
use_amp
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment