Unverified Commit 81ac45f8 authored by Lai Wei, committed by GitHub

update smddp api to v1.4.0 (#16371)



* update smddp api to v1.4.0

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* fix style

* remove unused import

* fix indent

* disable style check for import

* fix space
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a73281e3
@@ -90,10 +90,10 @@ class SageMakerTrainingArguments(TrainingArguments):
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
-            import smdistributed.dataparallel.torch.distributed as dist
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
 
-            dist.init_process_group()
-            self.local_rank = dist.get_local_rank()
+            torch.distributed.init_process_group(backend="smddp")
+            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
         elif self.local_rank == -1:
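With release 1.4.0, SageMaker's data parallel library registers itself as a regular PyTorch process-group backend, which is why the hunk above swaps the smdistributed-specific initialization for torch.distributed with backend="smddp". A minimal sketch of that setup path, assuming a SageMaker training job with smdistributed-dataparallel >= 1.4.0 installed (variable names here are illustrative):

import os

import torch
import torch.distributed as dist

# Importing this module registers the "smddp" backend with torch.distributed.
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

# Initialize the process group through the stock PyTorch API.
dist.init_process_group(backend="smddp")

# The library exports the per-node rank in an environment variable.
local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)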
@@ -51,6 +51,7 @@ from .integrations import ( # isort: split
 import numpy as np
 import torch
+import torch.distributed as dist
 from packaging import version
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@@ -170,11 +171,6 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as dist
-    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
-else:
-    import torch.distributed as dist
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -1078,7 +1074,9 @@ class Trainer:
             ).to(self.args.device)
         elif is_sagemaker_dp_enabled():
-            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+            )
         elif self.args.local_rank != -1:
             kwargs = {}
             if self.args.ddp_find_unused_parameters is not None:
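Since the process group is now a standard one, trainer.py drops the smdistributed DistributedDataParallel subclass and wraps the model with the stock torch.nn.parallel.DistributedDataParallel, pinned to the GPU given by SMDATAPARALLEL_LOCAL_RANK. A rough sketch of that wrapping step, assuming the "smddp" process group has already been initialized (the helper name is made up for illustration):

import os

import torch
from torch import nn

def wrap_model_for_smddp(model: nn.Module) -> nn.Module:
    # Same device selection as the hunk above: the SageMaker data parallel
    # library exports the local rank in an environment variable.
    local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
    model = model.to(torch.device("cuda", local_rank))
    # Plain PyTorch DDP; the "smddp" backend supplies the collectives.
    return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])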
@@ -29,25 +29,15 @@ from typing import Any, Dict, Iterator, List, Optional, Union
 import numpy as np
 import torch
+import torch.distributed as dist
 from packaging import version
 from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
 from .tokenization_utils_base import BatchEncoding
-from .utils import (
-    is_sagemaker_dp_enabled,
-    is_sagemaker_mp_enabled,
-    is_torch_tpu_available,
-    is_training_run_on_sagemaker,
-    logging,
-)
-
-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as dist
-else:
-    import torch.distributed as dist
+from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
 
 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))
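After this hunk, trainer_pt_utils.py always talks to torch.distributed directly; the conditional re-import of smdistributed.dataparallel.torch.distributed as dist is gone. As an illustration of why that works (this is not one of the file's actual helpers), a cross-rank gather looks identical whether the process group was created with backend="smddp" or "nccl":

import torch
import torch.distributed as dist

def gather_across_ranks(tensor: torch.Tensor) -> torch.Tensor:
    # Collect one copy of `tensor` from every process and concatenate them.
    output = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(output, tensor)
    return torch.cat(output, dim=0)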
@@ -32,7 +32,6 @@ import numpy as np
 
 from .utils import (
     ExplicitEnum,
     is_psutil_available,
-    is_sagemaker_dp_enabled,
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
@@ -263,10 +262,6 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm
 
         return xm.xrt_world_size()
-    elif is_sagemaker_dp_enabled():
-        import smdistributed.dataparallel.torch.distributed as dist
-
-        return dist.get_world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
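Removing the smddp branch from total_processes_number is safe because a data parallel run now sets local_rank (from SMDATAPARALLEL_LOCAL_RANK) and initializes a regular process group, so it falls through to the generic torch.distributed path. A simplified sketch of the resulting behaviour (the real helper also handles the TPU case):

import torch.distributed as dist

def total_processes_number(local_rank: int) -> int:
    # An smddp job has local_rank != -1 and an initialized process group,
    # so it takes the same path as any other torch.distributed backend.
    if local_rank != -1 and dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1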
@@ -41,12 +41,11 @@ from .utils import (
 
 if is_torch_available():
     import torch
+    import torch.distributed as dist
 
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
 
-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as sm_dist
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -1046,8 +1045,8 @@ class TrainingArguments:
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
-            sm_dist.init_process_group()
-            self.local_rank = sm_dist.get_local_rank()
+            dist.init_process_group(backend="smddp")
+            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
         elif self.deepspeed:
@@ -1149,7 +1148,7 @@
         elif is_sagemaker_mp_enabled():
             return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_world_size()
+            return dist.get_world_size()
         elif self.local_rank != -1:
             return torch.distributed.get_world_size()
         return 1
@@ -1165,7 +1164,7 @@
         elif is_sagemaker_mp_enabled():
             return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_rank()
+            return dist.get_rank()
         elif self.local_rank != -1:
             return torch.distributed.get_rank()
         return 0
@@ -1181,7 +1180,7 @@
         elif is_sagemaker_mp_enabled():
             return smp.local_rank()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_rank()
+            return dist.get_rank()
         elif self.local_rank != -1:
             return self.local_rank
         return 0
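With the sm_dist alias removed, world_size, process_index and local_process_index all answer through the same dist module (torch.distributed) that was initialized with backend="smddp". A condensed, illustrative version of that pattern, written as standalone functions rather than the actual TrainingArguments properties:

import torch.distributed as dist

def world_size() -> int:
    # Number of processes in the (smddp-backed) process group.
    return dist.get_world_size() if dist.is_initialized() else 1

def process_index() -> int:
    # Global rank of the current process.
    return dist.get_rank() if dist.is_initialized() else 0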
@@ -1281,7 +1280,7 @@
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
                     elif is_sagemaker_dp_enabled():
-                        sm_dist.barrier()
+                        dist.barrier()
                     else:
                         torch.distributed.barrier()
                 yield
@@ -1292,7 +1291,7 @@
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
                     elif is_sagemaker_dp_enabled():
-                        sm_dist.barrier()
+                        dist.barrier()
                     else:
                         torch.distributed.barrier()
         else:
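The two main_process_first hunks above follow the same substitution: the barrier is now the stock torch.distributed one. A small, self-contained sketch of that rendezvous pattern, assuming an initialized process group (this function is a simplified stand-in for the TrainingArguments method, not its actual body):

from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def main_process_first(is_main_process: bool):
    # Replicas wait here until the main process has finished the guarded
    # work (e.g. building a dataset cache) ...
    if not is_main_process:
        dist.barrier()
    try:
        yield
    finally:
        # ... then the main process releases them once it is done.
        if is_main_process:
            dist.barrier()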