Unverified Commit 81ac45f8 authored by Lai Wei, committed by GitHub

update smddp api to v1.4.0 (#16371)



* update smddp api to v1.4.0

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* fix style

* remove unused import

* fix indent

* disable style check for import

* fix space
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a73281e3
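
The substance of the change, before the per-file hunks: starting with SMDDP v1.4.0, `smdistributed.dataparallel` registers itself as a regular `torch.distributed` backend named `"smddp"`, so the Trainer can drop the library-specific `smdistributed.dataparallel.torch.distributed` wrappers and call the standard PyTorch APIs. A minimal sketch of the new initialization pattern, assuming it runs inside a SageMaker data-parallel job where the `smdistributed` package and the `SMDATAPARALLEL_LOCAL_RANK` environment variable are available:

```python
import os

import torch
import torch.distributed as dist

# SMDDP >= 1.4.0: importing this module registers the "smddp" backend with torch.distributed.
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

# Initialize the process group through the standard PyTorch API, selecting the smddp backend.
dist.init_process_group(backend="smddp")

# The local rank now comes from an environment variable instead of dist.get_local_rank().
local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", local_rank)

# Rank and world size are queried through the regular torch.distributed API.
print(dist.get_rank(), dist.get_world_size(), device)
```
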
@@ -90,10 +90,10 @@ class SageMakerTrainingArguments(TrainingArguments):
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
-            import smdistributed.dataparallel.torch.distributed as dist
-
-            dist.init_process_group()
-            self.local_rank = dist.get_local_rank()
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+
+            torch.distributed.init_process_group(backend="smddp")
+            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
         elif self.local_rank == -1:
...
@@ -51,6 +51,7 @@ from .integrations import (  # isort: split
 import numpy as np
 import torch
+import torch.distributed as dist
 from packaging import version
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@@ -170,11 +171,6 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler

-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as dist
-    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
-else:
-    import torch.distributed as dist

 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -1078,7 +1074,9 @@ class Trainer:
             ).to(self.args.device)

         elif is_sagemaker_dp_enabled():
-            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+            )
         elif self.args.local_rank != -1:
             kwargs = {}
             if self.args.ddp_find_unused_parameters is not None:
...
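
With the backend registered, the Trainer no longer needs SMDDP's own `DistributedDataParallel` wrapper; the hunk above switches to PyTorch's native DDP. A rough sketch of that wrapping step, assuming the process group was initialized with `backend="smddp"` and `model` is an `nn.Module` (the helper name is hypothetical):

```python
import os

import torch
from torch import nn


def wrap_for_smddp(model: nn.Module) -> nn.Module:
    """Hypothetical helper mirroring the Trainer change: wrap with native DDP."""
    local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
    model = model.to(torch.device("cuda", local_rank))
    # Native DistributedDataParallel; gradient all-reduce is handled by the registered smddp backend.
    return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```
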
@@ -29,25 +29,15 @@ from typing import Any, Dict, Iterator, List, Optional, Union

 import numpy as np
 import torch
+import torch.distributed as dist
 from packaging import version
 from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

 from .tokenization_utils_base import BatchEncoding
-from .utils import (
-    is_sagemaker_dp_enabled,
-    is_sagemaker_mp_enabled,
-    is_torch_tpu_available,
-    is_training_run_on_sagemaker,
-    logging,
-)
-
-
-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as dist
-else:
-    import torch.distributed as dist
+from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging


 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))
...
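
Because `smddp` now plugs into `torch.distributed`, the tensor-gathering utilities in this file can use the generic collectives with no SageMaker-specific branch. A small illustrative sketch of the kind of `all_gather`-based concatenation these helpers rely on, assuming an already-initialized process group (any backend, including `smddp`) and same-shaped tensors on every rank:

```python
import torch
import torch.distributed as dist


def gather_across_processes(tensor: torch.Tensor) -> torch.Tensor:
    """Illustrative only: concatenate a same-shaped tensor from every process."""
    output = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(output, tensor)  # every rank receives every rank's tensor
    return torch.cat(output, dim=0)
```
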
@@ -32,7 +32,6 @@ import numpy as np
 from .utils import (
     ExplicitEnum,
     is_psutil_available,
-    is_sagemaker_dp_enabled,
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
@@ -263,10 +262,6 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
-    elif is_sagemaker_dp_enabled():
-        import smdistributed.dataparallel.torch.distributed as dist
-
-        return dist.get_world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
...
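
After the removal above, `total_processes_number` no longer special-cases SMDDP; the count goes through the same `torch.distributed` call for any initialized backend. A hedged sketch of that logic (the `is_initialized()` guard is an added safety assumption, not taken from the diff):

```python
import torch.distributed as dist


def total_processes(local_rank: int) -> int:
    # Once init_process_group has run, any backend (nccl, gloo, smddp, ...) reports the same way.
    if local_rank != -1 and dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```
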
@@ -41,12 +41,11 @@ from .utils import (

 if is_torch_available():
     import torch
+    import torch.distributed as dist

 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm

-if is_sagemaker_dp_enabled():
-    import smdistributed.dataparallel.torch.distributed as sm_dist

 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -1046,8 +1045,8 @@ class TrainingArguments:
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
-            sm_dist.init_process_group()
-            self.local_rank = sm_dist.get_local_rank()
+            dist.init_process_group(backend="smddp")
+            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
         elif self.deepspeed:
@@ -1149,7 +1148,7 @@ class TrainingArguments:
         elif is_sagemaker_mp_enabled():
             return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_world_size()
+            return dist.get_world_size()
         elif self.local_rank != -1:
             return torch.distributed.get_world_size()
         return 1
@@ -1165,7 +1164,7 @@ class TrainingArguments:
         elif is_sagemaker_mp_enabled():
             return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_rank()
+            return dist.get_rank()
         elif self.local_rank != -1:
             return torch.distributed.get_rank()
         return 0
@@ -1181,7 +1180,7 @@ class TrainingArguments:
         elif is_sagemaker_mp_enabled():
             return smp.local_rank()
         elif is_sagemaker_dp_enabled():
-            return sm_dist.get_rank()
+            return dist.get_rank()
         elif self.local_rank != -1:
             return self.local_rank
         return 0
@@ -1281,7 +1280,7 @@ class TrainingArguments:
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
                     elif is_sagemaker_dp_enabled():
-                        sm_dist.barrier()
+                        dist.barrier()
                     else:
                         torch.distributed.barrier()
                 yield
@@ -1292,7 +1291,7 @@ class TrainingArguments:
                    if is_torch_tpu_available():
                        xm.rendezvous(desc)
                    elif is_sagemaker_dp_enabled():
-                        sm_dist.barrier()
+                        dist.barrier()
                    else:
                        torch.distributed.barrier()
         else:
...
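
The two `main_process_first` hunks above follow the same pattern: the SageMaker DP branch now falls through to a plain `dist.barrier()`. A simplified sketch of that synchronization pattern (not the Trainer's exact code; it only shows how the barriers gate non-main ranks around work done once on the main process):

```python
import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def main_process_first(is_main_process: bool):
    """Simplified: non-main ranks wait before the block; the main rank releases them afterwards."""
    try:
        if not is_main_process:
            dist.barrier()  # wait until the main process has finished the guarded work
        yield
    finally:
        if is_main_process:
            dist.barrier()  # release the waiting ranks once the work is done
```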