Unverified Commit 0d0efd3a authored by Sylvain Gugger, committed by GitHub

Smdistributed trainer (#9798)

* Add a debug print

* Adapt Trainer to use smdistributed if available

* Forgotten parenthesis

* Real check for sagemaker

* Don't forget to define device...

* Woopsie, local_rank is defined differently

* Update since local_rank has the proper value

* Remove debug statement

* More robust check for smdistributed

* Quality

* Deal with key not present error

parent 897a24c8
@@ -297,6 +297,20 @@ def is_pandas_available():
     return importlib.util.find_spec("pandas") is not None
 
 
+def is_sagemaker_distributed_available():
+    # Get the sagemaker specific env variable.
+    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
+        sagemaker_params = json.loads(sagemaker_params)
+        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
+    return importlib.util.find_spec("smdistributed") is not None
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
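For illustration, here is a minimal, self-contained sketch of how this detection behaves. The example value of SM_FRAMEWORK_PARAMS is an assumption about what the SageMaker data-parallel launcher sets; the real helper is the one added above.

import importlib.util
import json
import os

def sagemaker_data_parallel_enabled() -> bool:
    # Mirrors the helper above: read the env variable set by the SageMaker launcher...
    params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        enabled = json.loads(params).get("sagemaker_distributed_dataparallel_enabled", False)
    except json.JSONDecodeError:
        return False
    # ...and require the smdistributed package to actually be importable.
    return bool(enabled) and importlib.util.find_spec("smdistributed") is not None

# Outside a SageMaker job the variable is unset, so the check is False.
print(sagemaker_data_parallel_enabled())

# Assumed example of what the launcher sets on a data-parallel job; even then
# the check only passes if `smdistributed` is installed.
os.environ["SM_FRAMEWORK_PARAMS"] = '{"sagemaker_distributed_dataparallel_enabled": true}'
print(sagemaker_data_parallel_enabled())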
@@ -51,7 +51,14 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import (
+    WEIGHTS_NAME,
+    is_apex_available,
+    is_datasets_available,
+    is_in_notebook,
+    is_sagemaker_distributed_available,
+    is_torch_tpu_available,
+)
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import Adafactor, AdamW, get_scheduler

@@ -125,6 +132,11 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
+else:
+    import torch.distributed as dist
 
 if TYPE_CHECKING:
     import optuna

@@ -428,9 +440,12 @@ class Trainer:
         if self.args.parallel_mode == ParallelMode.TPU:
             num_processes = xm.xrt_world_size()
             process_index = xm.get_ordinal()
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-            num_processes = torch.distributed.get_world_size()
-            process_index = torch.distributed.get_rank()
+        elif (
+            self.args.parallel_mode == ParallelMode.DISTRIBUTED
+            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
+        ):
+            num_processes = dist.get_world_size()
+            process_index = dist.get_rank()
         else:
             num_processes = 1
             process_index = 0

@@ -743,6 +758,8 @@ class Trainer:
         # Distributed training (should be after apex fp16 initialization)
         if self.sharded_dpp:
             model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
@@ -767,14 +784,13 @@ class Trainer:
         # Train!
         if is_torch_tpu_available():
-            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
+            world_size = xm.xrt_world_size()
+        elif self.args.local_rank != -1:
+            world_size = dist.get_world_size()
         else:
-            total_train_batch_size = (
-                self.args.train_batch_size
-                * self.args.gradient_accumulation_steps
-                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
-            )
+            world_size = 1
+
+        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
         num_examples = (
             self.num_examples(train_dataloader)
             if train_dataset_is_sized
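A worked example of the effective batch size computed above, with illustrative values that are not from the PR: a per-device batch size of 8, gradient accumulation of 4, and 8 data-parallel workers give 8 * 4 * 8 = 256 samples per optimizer step.

# Illustrative numbers only; world_size would come from dist.get_world_size()
# (or xm.xrt_world_size() on TPU) exactly as in the hunk above.
train_batch_size = 8               # per-device batch size
gradient_accumulation_steps = 4
world_size = 8                     # e.g. one instance with 8 GPUs
total_train_batch_size = train_batch_size * gradient_accumulation_steps * world_size
assert total_train_batch_size == 256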
@@ -1302,7 +1318,7 @@ class Trainer:
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
         else:
-            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
+            return self.args.local_rank == -1 or dist.get_rank() == 0
 
     def save_model(self, output_dir: Optional[str] = None):
         """

@@ -1542,7 +1558,7 @@ class Trainer:
         if is_torch_tpu_available():
             world_size = xm.xrt_world_size()
         elif self.args.local_rank != -1:
-            world_size = torch.distributed.get_world_size()
+            world_size = dist.get_world_size()
         world_size = max(1, world_size)
 
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
-from .file_utils import is_torch_tpu_available
+from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
 from .utils import logging
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+else:
+    import torch.distributed as dist
+
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm

@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
-        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
+        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler
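To make the gather pattern concrete, here is a self-contained sketch that runs the same clone / all_gather / cat sequence in a single process using the gloo backend. It is illustrative only; the Trainer uses whichever `dist` backend was imported above.

import os
import torch
import torch.distributed as dist

# Single-process bootstrap so the example runs standalone; real jobs are
# launched by torchrun or the SageMaker data-parallel launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

def gather_concat(tensor: torch.Tensor) -> torch.Tensor:
    # Same pattern as distributed_concat: one buffer per process, all_gather, cat.
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    return torch.cat(output_tensors, dim=0)

print(gather_concat(torch.arange(4)))  # with world_size == 1, just the input back
dist.destroy_process_group()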
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
 ) -> torch.Tensor:
     try:
         tensorized_scalar = torch.tensor(scalars).cuda()
-        output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensorized_scalar)
+        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensorized_scalar)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler

@@ -167,10 +173,10 @@ def torch_distributed_zero_first(local_rank: int):
         local_rank (:obj:`int`): The rank of the local process.
     """
     if local_rank not in [-1, 0]:
-        torch.distributed.barrier()
+        dist.barrier()
     yield
     if local_rank == 0:
-        torch.distributed.barrier()
+        dist.barrier()
 
 
 class SequentialDistributedSampler(Sampler):

@@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler):
     def __init__(self, dataset, num_replicas=None, rank=None):
         if num_replicas is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            num_replicas = torch.distributed.get_world_size()
+            num_replicas = dist.get_world_size()
         if rank is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            rank = torch.distributed.get_rank()
+            rank = dist.get_rank()
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank

@@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         lengths: Optional[List[int]] = None,
     ):
         if num_replicas is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            num_replicas = torch.distributed.get_world_size()
+            num_replicas = dist.get_world_size()
         if rank is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            rank = torch.distributed.get_rank()
+            rank = dist.get_rank()
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_replicas = num_replicas
@@ -25,7 +25,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 
-from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
+from .file_utils import is_sagemaker_distributed_available, is_tf_available, is_torch_available, is_torch_tpu_available
 from .tokenization_utils_base import ExplicitEnum

@@ -187,6 +187,10 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm
 
         return xm.xrt_world_size()
+    elif is_sagemaker_distributed_available():
+        import smdistributed.dataparallel.torch.distributed as dist
+
+        return dist.get_world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
+from .file_utils import (
+    cached_property,
+    is_sagemaker_distributed_available,
+    is_torch_available,
+    is_torch_tpu_available,
+    torch_required,
+)
 from .trainer_utils import EvaluationStrategy, SchedulerType
 from .utils import logging

@@ -493,6 +499,13 @@ class TrainingArguments:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
+        elif is_sagemaker_distributed_available():
+            import smdistributed.dataparallel.torch.distributed as dist
+
+            dist.init_process_group()
+            self.local_rank = dist.get_local_rank()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
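A hedged sketch of the per-process setup this branch performs on a SageMaker data-parallel job. It is only runnable inside such a job, and the torch.cuda.set_device call is an illustrative addition, not part of the diff.

import torch
import smdistributed.dataparallel.torch.distributed as dist

dist.init_process_group()              # rendezvous is handled by the SageMaker launcher
local_rank = dist.get_local_rank()     # GPU index of this process on its instance
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)          # not in the diff; pins the process to its GPU
print(f"rank {dist.get_rank()} of {dist.get_world_size()} -> {device}")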
@@ -566,6 +579,8 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return ParallelMode.TPU
+        elif is_sagemaker_distributed_available():
+            return ParallelMode.SAGEMAKER_DISTRIBUTED
         elif self.local_rank != -1:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:

@@ -607,4 +622,5 @@ class ParallelMode(Enum):
     NOT_PARALLEL = "not_parallel"
     NOT_DISTRIBUTED = "not_distributed"
     DISTRIBUTED = "distributed"
+    SAGEMAKER_DISTRIBUTED = "sm_distributed"
     TPU = "tpu"