Commit 269c9638 authored by Sylvain Gugger

Merge branch 'master' of github.com:huggingface/transformers

parents d31c7b10 c2e0fd52
......@@ -348,7 +348,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install ."[all, docs]"
- run: pip install ."[docs]"
- save_cache:
key: v0.4-build_doc-{{ checksum "setup.py" }}
paths:
......@@ -370,7 +370,7 @@ jobs:
keys:
- v0.4-deploy_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install ."[all,docs]"
- run: pip install ."[docs]"
- save_cache:
key: v0.4-deploy_doc-{{ checksum "setup.py" }}
paths:
......
......@@ -33,7 +33,7 @@ jobs:
run: |
apt -y update && apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed]
- name: Are GPUs recognized by our DL frameworks
run: |
......@@ -155,7 +155,7 @@ jobs:
run: |
apt -y update && apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale]
- name: Are GPUs recognized by our DL frameworks
run: |
......
......@@ -274,6 +274,14 @@ Install the library via pypi:
pip install fairscale
or via ``transformers``' ``extras``:
.. code-block:: bash
pip install transformers[fairscale]
(will become available starting from ``transformers==4.6.0``)
or find more details on `FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
......@@ -419,6 +427,14 @@ Install the library via pypi:
pip install deepspeed
or via ``transformers``' ``extras``:
.. code-block:: bash
pip install transformers[deepspeed]
(will become available starting from ``transformers==4.6.0``)
or find more details on `DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
`advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
......@@ -525,7 +541,7 @@ Here is an example of running ``run_translation.py`` under DeepSpeed deploying a
.. code-block:: bash
deepspeed examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \
--deepspeed tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \
......@@ -550,7 +566,7 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma
.. code-block:: bash
deepspeed --num_gpus=1 examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \
--deepspeed tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \
......
......@@ -795,6 +795,23 @@ leave any data in there.
otherwise.
Temporary sys.path override
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you need to temporarily override ``sys.path``, for example to import from another test, you can use the
``ExtendSysPath`` context manager. Example:
.. code-block:: python
import os
from transformers.testing_utils import ExtendSysPath
bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/.."):
from test_trainer import TrainerIntegrationCommon # noqa
Skipping tests
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -422,7 +422,12 @@ def main():
# Data collator
# This one will take care of randomly masking the tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm_probability=data_args.mlm_probability,
pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
)
# Initialize our Trainer
trainer = Trainer(
......
......@@ -85,11 +85,14 @@ if stale_egg_info.exists():
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
"black>=20.8b1",
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>0.3.13",
"docutils==0.16.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",
......@@ -102,13 +105,13 @@ _deps = [
"jax>=0.2.8",
"jaxlib>=0.1.59",
"keras2onnx",
"nltk",
"numpy>=1.17",
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
"onnxruntime>=1.4.0",
"packaging",
"parameterized",
"Pillow",
"protobuf",
"psutil",
"pydantic",
......@@ -119,15 +122,18 @@ _deps = [
"recommonmark",
"regex!=2019.12.17",
"requests",
"rouge-score",
"sacrebleu>=1.4.12",
"sacremoses",
"sagemaker>=2.31.0",
"scikit-learn",
"sentencepiece==0.1.91",
"soundfile",
"sphinx-copybutton",
"sphinx-markdown-tables",
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
"sphinxext-opengraph==0.4.1",
"sphinx==3.2.1",
"sphinxext-opengraph==0.4.1",
"starlette",
"tensorflow-cpu>=2.3",
"tensorflow>=2.3",
......@@ -139,7 +145,6 @@ _deps = [
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
"uvicorn",
"sagemaker>=2.31.0",
]
......@@ -230,6 +235,8 @@ extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxr
extras["modelcreation"] = deps_list("cookiecutter")
extras["sagemaker"] = deps_list("sagemaker")
extras["deepspeed"] = deps_list("deepspeed")
extras["fairscale"] = deps_list("fairscale")
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile", "torchaudio")
......@@ -238,20 +245,12 @@ extras["vision"] = deps_list("Pillow")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
deps_list(
"pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black"
"pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black", "sacrebleu", "rouge-score", "nltk"
)
+ extras["retrieval"]
+ extras["modelcreation"]
)
extras["docs"] = deps_list(
"docutils",
"recommonmark",
"sphinx",
"sphinx-markdown-tables",
"sphinx-rtd-theme",
"sphinx-copybutton",
"sphinxext-opengraph",
)
extras["quality"] = deps_list("black", "isort", "flake8")
extras["all"] = (
......@@ -264,12 +263,24 @@ extras["all"] = (
+ extras["vision"]
)
extras["docs_specific"] = deps_list(
"docutils",
"recommonmark",
"sphinx",
"sphinx-markdown-tables",
"sphinx-rtd-theme",
"sphinx-copybutton",
"sphinxext-opengraph",
)
# "docs" needs "all" to resolve all the references
extras["docs"] = extras["all"] + extras["docs_specific"]
extras["dev"] = (
extras["all"]
+ extras["testing"]
+ extras["quality"]
+ extras["ja"]
+ extras["docs"]
+ extras["docs_specific"]
+ extras["sklearn"]
+ extras["modelcreation"]
)
......
......@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
return batch
def _collate_batch(examples, tokenizer):
def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
# Tensorize if necessary.
if isinstance(examples[0], (list, tuple)):
......@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
# Check if padding is necessary.
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
return torch.stack(examples, dim=0)
# If padding is needed, check that we have a `pad_token`.
......@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
# Creating the full tensor and filling it with our data.
max_length = max(x.size(0) for x in examples)
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
for i, example in enumerate(examples):
if tokenizer.padding_side == "right":
......@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
non-masked tokens and the value to predict for the masked token.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
pad_to_multiple_of (:obj:`int`, `optional`):
If set, will pad the sequence to a multiple of the provided value.
.. note::
......@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
pad_to_multiple_of: Optional[int] = None
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
......@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
) -> Dict[str, torch.Tensor]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
batch = self.tokenizer.pad(examples, return_tensors="pt")
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
else:
batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
......
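A minimal sketch of how the reworked collator might be exercised, mirroring the ``run_mlm.py`` change above (the checkpoint name is illustrative, not part of this commit):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# With pad_to_multiple_of=8, a batch whose longest sequence is 37 tokens
# is padded to ((37 // 8) + 1) * 8 = 40, the shape fp16 tensor cores prefer.
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15, pad_to_multiple_of=8
)
batch = collator([tokenizer("Hello world"), tokenizer("A longer example sentence")])
assert batch["input_ids"].shape[-1] % 8 == 0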
......@@ -14,7 +14,7 @@
import sys
from .dependency_versions_table import deps
from .utils.versions import require_version_core
from .utils.versions import require_version, require_version_core
# define which module versions we always want to check at run time
......@@ -41,3 +41,7 @@ for pkg in pkgs_to_check_at_runtime:
require_version_core(deps[pkg])
else:
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
def dep_version_check(pkg, hint=None):
require_version(deps[pkg], hint)
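For illustration, the new helper lets callers assert a pinned dependency at runtime by key, rather than repeating the version string, as ``integrations.py`` does below for DeepSpeed:

from transformers.dependency_versions_check import dep_version_check

# Raises if fairscale is missing or does not satisfy the "fairscale>0.3"
# pin recorded in dependency_versions_table.py.
dep_version_check("fairscale")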
......@@ -2,11 +2,14 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
deps = {
"Pillow": "Pillow",
"black": "black>=20.8b1",
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>0.3.13",
"docutils": "docutils==0.16.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",
......@@ -19,13 +22,13 @@ deps = {
"jax": "jax>=0.2.8",
"jaxlib": "jaxlib>=0.1.59",
"keras2onnx": "keras2onnx",
"nltk": "nltk",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
"packaging": "packaging",
"parameterized": "parameterized",
"Pillow": "Pillow",
"protobuf": "protobuf",
"psutil": "psutil",
"pydantic": "pydantic",
......@@ -36,15 +39,18 @@ deps = {
"recommonmark": "recommonmark",
"regex": "regex!=2019.12.17",
"requests": "requests",
"rouge-score": "rouge-score",
"sacrebleu": "sacrebleu>=1.4.12",
"sacremoses": "sacremoses",
"sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece==0.1.91",
"soundfile": "soundfile",
"sphinx-copybutton": "sphinx-copybutton",
"sphinx-markdown-tables": "sphinx-markdown-tables",
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
"sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
"sphinx": "sphinx==3.2.1",
"sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3",
"tensorflow": "tensorflow>=2.3",
......@@ -56,5 +62,4 @@ deps = {
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
"sagemaker": "sagemaker>=2.31.0",
}
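As a quick sanity check, the regenerated table can be inspected directly; the two pins below are the ones this commit adds:

from transformers.dependency_versions_table import deps

print(deps["deepspeed"])  # deepspeed>0.3.13
print(deps["fairscale"])  # fairscale>0.3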
......@@ -24,8 +24,8 @@ import tempfile
from copy import deepcopy
from pathlib import Path
from .dependency_versions_check import dep_version_check
from .utils import logging
from .utils.versions import require_version
logger = logging.get_logger(__name__)
......@@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config):
If it's already a dict, return a copy of it, so that we can freely modify it.
"""
require_version("deepspeed>0.3.13")
dep_version_check("deepspeed")
if isinstance(ds_config, dict):
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
......@@ -604,7 +604,9 @@ class TensorBoardCallback(TrainerCallback):
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_world_process_zero:
if not state.is_world_process_zero:
return
if self.tb_writer is None:
self._init_summary_writer(args)
......
......@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
# get abs dir
save_directory = os.path.abspath(save_directory)
# save config as well
self.config.architectures = [self.__class__.__name__[4:]]
self.config.save_pretrained(save_directory)
# save model
......
......@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
logger.info(f"Saved model created in {saved_model_dir}")
# Save configuration file
self.config.architectures = [self.__class__.__name__[2:]]
self.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained`
......
......@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
_import_structure = {
"auto_factory": ["get_values"],
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
......@@ -104,6 +105,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .auto_factory import get_values
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
......
......@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
"""
def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models
name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]
# If no architecture is set in the config, or none of them match the supported models, the first element of
# the tuple is the default.
return supported_models[0]
class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None
......@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)](config, **kwargs)
model_class = _get_model_class(config, cls._model_mapping)
return model_class(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
......@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
)
if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
......@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
new_class.from_pretrained = classmethod(from_pretrained)
return new_class
def get_values(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)
return result
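A sketch of what the new resolution logic buys, using the Funnel mapping changed below (the checkpoint name is only an example). When several model classes share one config, ``config.architectures`` (now saved by the TF/Flax ``save_pretrained`` changes above) selects the matching class instead of always falling back to the first entry:

from transformers import AutoConfig, AutoModel, FunnelBaseModel

config = AutoConfig.from_pretrained("funnel-transformer/small-base")
# FunnelConfig maps to (FunnelModel, FunnelBaseModel); an explicit
# architectures hint picks the second entry, otherwise the first is the default.
config.architectures = ["FunnelBaseModel"]
model = AutoModel.from_config(config)
assert isinstance(model, FunnelBaseModel)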
......@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
)
def _get_class_name(model_class):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
return f":class:`~transformers.{model_class.__name__}`"
def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config_to_class is None and not use_model_types:
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
if use_model_types:
if config_to_class is None:
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
model_type_to_name = {
model_type: f":class:`~transformers.{config.__name__}`"
for model_type, config in CONFIG_MAPPING.items()
}
else:
model_type_to_name = {
model_type: config_to_class[config].__name__
model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class
}
lines = [
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
}
lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
for config_name in sorted(config_to_name.keys())
]
return "\n".join(lines)
......
......@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
)
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from ..funnel.modeling_funnel import (
FunnelBaseModel,
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
......@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
(FunnelConfig, FunnelModel),
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DebertaConfig, DebertaModel),
......
......@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM,
TFFunnelForMultipleChoice,
TFFunnelForPreTraining,
......@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
(XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
......
......@@ -24,6 +24,7 @@ import unittest
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Iterator, Union
from .file_utils import (
is_datasets_available,
......@@ -621,6 +622,27 @@ class CaptureLogger:
return f"captured: {self.out}\n"
@contextlib.contextmanager
# adapted from https://stackoverflow.com/a/64789046/9201239
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
"""
Temporarily add a given path to `sys.path`.
Usage::
with ExtendSysPath('/path/to/dir'):
mymodule = importlib.import_module('mymodule')
"""
path = os.fspath(path)
try:
sys.path.insert(0, path)
yield
finally:
sys.path.remove(path)
class TestCasePlus(unittest.TestCase):
"""
This class extends `unittest.TestCase` with additional features.
......
......@@ -54,6 +54,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .dependency_versions_check import dep_version_check
from .file_utils import (
WEIGHTS_NAME,
is_apex_available,
......@@ -139,17 +140,14 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
dep_version_check("fairscale")
import fairscale
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.nn.wrap import auto_wrap
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap
else:
FullyShardedDDP = None
if is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
......
......@@ -531,6 +531,12 @@ class TrainingArguments:
)
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
# This needs to happen before any call to self.device or self.n_gpu.
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank
# Expand paths; otherwise os.makedirs("~/bar") would create the directory
# in the current working directory instead of the actual home
# see https://github.com/huggingface/transformers/issues/10628
......
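A hedged illustration of the ``--use_env`` fix above: launchers that only export the environment variable (instead of passing ``--local_rank``) should now be reflected in ``TrainingArguments``:

import os
from transformers import TrainingArguments

# torch.distributed.launch --use_env exports LOCAL_RANK rather than passing
# --local_rank; __post_init__ now reconciles the two values.
os.environ["LOCAL_RANK"] = "0"
args = TrainingArguments(output_dir="output_dir")
assert args.local_rank == 0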