Unverified Commit f06d0fad authored by Stas Bekman, committed by GitHub

[trainer] apex fixes and tests (#9180)

parent 467e9158
@@ -18,7 +18,7 @@ import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
-from transformers.file_utils import is_datasets_available
+from transformers.file_utils import is_apex_available, is_datasets_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
     return test_case
 
 
+# a candidate for testing_utils
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
     def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
         output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    # it's crucial to test --sharded_ddp w/ and w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
     def test_finetune_trainer_ddp_sharded_ddp(self):
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
         self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
 
+    @require_apex
+    def test_finetune_trainer_apex(self):
+        self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
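A minimal standalone sketch (not part of this commit) of how the new require_apex decorator behaves: a test marked with it runs only when apex can be imported and is reported as skipped otherwise. The is_apex_available stand-in and the ApexDecoratorTest class below are illustrative assumptions; the real check lives in transformers.file_utils.

import importlib.util
import unittest


def is_apex_available():
    # stand-in for transformers.file_utils.is_apex_available (assumption)
    return importlib.util.find_spec("apex") is not None


def require_apex(test_case):
    """Decorator marking a test that requires apex."""
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    return test_case


class ApexDecoratorTest(unittest.TestCase):
    @require_apex
    def test_runs_only_with_apex(self):
        # executed only when apex is installed; otherwise the test is skipped
        from apex import amp  # noqa: F401

        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()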
...@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup from .optimization import AdamW, get_linear_schedule_with_warmup
...@@ -104,13 +104,10 @@ if is_in_notebook(): ...@@ -104,13 +104,10 @@ if is_in_notebook():
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex if is_apex_available():
if version.parse(torch.__version__) < version.parse("1.6"): from apex import amp
from .file_utils import is_apex_available
if is_apex_available(): if version.parse(torch.__version__) >= version.parse("1.6"):
from apex import amp
else:
_is_native_amp_available = True _is_native_amp_available = True
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
...@@ -309,6 +306,7 @@ class Trainer: ...@@ -309,6 +306,7 @@ class Trainer:
backend = "amp" if _is_native_amp_available else "apex" backend = "amp" if _is_native_amp_available else "apex"
else: else:
backend = args.fp16_backend backend = args.fp16_backend
logger.info(f"Using {backend} fp16 backend")
if backend == "amp": if backend == "amp":
self.use_amp = True self.use_amp = True
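For context, a hedged sketch of what the guard logic amounts to after this change: the apex import is gated only on apex being installed (a try/except stands in here for is_apex_available), native AMP is gated only on the torch version, and "auto" prefers native AMP over apex. The pick_fp16_backend helper is a hypothetical name for illustration, not an API of the library.

import torch
from packaging import version

_is_native_amp_available = False

try:
    from apex import amp  # noqa: F401  # used only when the apex backend is selected

    _is_apex_available = True
except ImportError:
    _is_apex_available = False

# native AMP (torch.cuda.amp) exists only from PyTorch 1.6 on
if version.parse(torch.__version__) >= version.parse("1.6"):
    from torch.cuda.amp import autocast  # noqa: F401

    _is_native_amp_available = True


def pick_fp16_backend(requested: str = "auto") -> str:
    # hypothetical helper mirroring the Trainer's selection: "auto" prefers
    # native AMP and falls back to apex; an explicit choice is passed through
    if requested == "auto":
        return "amp" if _is_native_amp_available else "apex"
    return requested


print(f"Using {pick_fp16_backend('auto')} fp16 backend")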