Unverified Commit 8c297cdb authored by Philipp Schmid, committed by GitHub

Sm trainer smp init fix (#10870)

* rewrote is_sagemaker_model_parallel_available

* added is_sagemaker_model_parallel_available to SageMakerTrainer

* removed unnecessary mp_parameters check from TrainingArguments

* make style happy

* added mp_parameters again to parse mp-specific args.
parent d4d4447d
@@ -34,13 +34,13 @@ from ..trainer_pt_utils import (
 )
 from ..trainer_utils import PREFIX_CHECKPOINT_DIR
 from ..utils import logging
-from .training_args_sm import is_smdistributed_available
+from .training_args_sm import is_sagemaker_model_parallel_available


 logger = logging.get_logger(__name__)


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

     @smp.step()
@@ -79,7 +79,7 @@ if is_smdistributed_available():

 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
+        self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)

     def is_world_process_zero(self) -> bool:
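With this change, SageMakerTrainer no longer inspects args.mp_parameters to decide whether model parallelism is active; the renamed helper is the single switch. Below is a minimal sketch (not part of the commit) of querying that helper, assuming the package layout implied by the relative imports above, i.e. that it lives in transformers.sagemaker.training_args_sm:

    from transformers.sagemaker.training_args_sm import is_sagemaker_model_parallel_available

    # Returns True only inside a SageMaker model-parallel job (the launcher sets the
    # relevant environment variables) with the `smdistributed` package installed;
    # everywhere else it returns False and the trainer behaves like the regular Trainer.
    print("model parallel enabled:", is_sagemaker_model_parallel_available())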
@@ -13,6 +13,8 @@
 # limitations under the License.

 import importlib.util
+import json
+import os
 from dataclasses import dataclass, field

 import torch
@@ -24,33 +26,53 @@ from transformers.utils import logging

 logger = logging.get_logger(__name__)


-def is_smdistributed_available():
+# TODO: should be moved to `file_utils` after refactoring of SageMakerTrainer
+def is_sagemaker_model_parallel_available():
+    # Get the sagemaker specific mp parameters from smp_options variable.
+    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
+    try:
+        # Parse it and check the field "partitions" is included, it is required for model parallel.
+        smp_options = json.loads(smp_options)
+        if "partitions" not in smp_options:
+            return False
+    except json.JSONDecodeError:
+        return False
+
+    # Get the sagemaker specific framework parameters from mpi_options variable.
+    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_mpi_enabled".
+        mpi_options = json.loads(mpi_options)
+        if not mpi_options.get("sagemaker_mpi_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
     return importlib.util.find_spec("smdistributed") is not None


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

+    smp.init()
+

 @dataclass
 class SageMakerTrainingArguments(TrainingArguments):
     mp_parameters: str = field(
-        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
     )

-    def __post_init__(self):
-        super().__post_init__()
-        if is_smdistributed_available() and self.mp_parameters != "":
-            smp.init()
-
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-        elif is_smdistributed_available() and self.mp_parameters != "":
+        elif is_sagemaker_model_parallel_available():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
@@ -86,14 +108,14 @@ class SageMakerTrainingArguments(TrainingArguments):

     @property
     def world_size(self):
-        if is_smdistributed_available() and self.mp_parameters != "":
+        if is_sagemaker_model_parallel_available():
             return smp.dp_size()
         return super().world_size

     @property
     def place_model_on_device(self):
-        return not (is_smdistributed_available() and self.mp_parameters != "")
+        return not is_sagemaker_model_parallel_available()

     @property
     def _no_sync_in_gradient_accumulation(self):
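The new check is driven entirely by environment variables set by the SageMaker launcher. The sketch below (not part of the commit) re-implements that parsing locally rather than importing the library function, using hypothetical example values; the real helper additionally requires the `smdistributed` module to be importable:

    import json
    import os

    # Hypothetical values resembling what a SageMaker model-parallel job would set.
    os.environ["SM_HP_MP_PARAMETERS"] = json.dumps({"partitions": 2, "microbatches": 4})
    os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps({"sagemaker_mpi_enabled": True})

    # Mirror the two checks from is_sagemaker_model_parallel_available(), minus the
    # final importlib.util.find_spec("smdistributed") test.
    smp_options = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    mpi_options = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))

    enabled = "partitions" in smp_options and mpi_options.get("sagemaker_mpi_enabled", False)
    print("environment indicates model parallelism:", enabled)  # True with the values above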