Unverified commit f689743e authored by Philipp Schmid, committed by GitHub

SageMaker: Fix sagemaker DDP & metric logs (#13181)



* Barrier -> barrier

* added logger for metrics

* removed stream handler in trainer

* moved handler

* removed streamhandler from trainer

* updated test image and instance type, added datasets version to test

* Update tests/sagemaker/scripts/pytorch/requirements.txt
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 8679bd71
...
@@ -26,7 +26,6 @@ import shutil
 import sys
 import time
 import warnings
-from logging import StreamHandler
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
...
@@ -68,7 +67,6 @@ from .file_utils import (
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_tpu_available,
-    is_training_run_on_sagemaker,
 )
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, unwrap_model
...
@@ -173,9 +171,6 @@ if is_sagemaker_mp_enabled():
     from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat

-if is_training_run_on_sagemaker():
-    logging.add_handler(StreamHandler(sys.stdout))
-
 if TYPE_CHECKING:
     import optuna
...
...
@@ -20,9 +20,11 @@ import datetime
 import json
 import math
 import os
+import sys
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
+from logging import StreamHandler
 from typing import Dict, Iterator, List, Optional, Union

 import numpy as np
...
@@ -32,7 +34,12 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

-from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
+from .file_utils import (
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
+    is_torch_tpu_available,
+    is_training_run_on_sagemaker,
+)
 from .tokenization_utils_base import BatchEncoding
 from .utils import logging
...
@@ -42,6 +49,8 @@ if is_sagemaker_dp_enabled():
 else:
     import torch.distributed as dist

+if is_training_run_on_sagemaker():
+    logging.add_handler(StreamHandler(sys.stdout))
+
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
...
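The "moved handler" from the commit message is visible here: the stdout handler removed from the trainer module above is re-added next to the distributed imports, so metric logs emitted through the library logger land on stdout, which SageMaker forwards to CloudWatch. A minimal sketch of the pattern, assuming SageMaker detection via an environment variable (the exact variable is an assumption, not taken from this diff):

import logging
import os
import sys


def is_training_run_on_sagemaker() -> bool:
    # Assumption for this sketch: SageMaker training containers expose
    # a job-name environment variable such as SAGEMAKER_JOB_NAME.
    return "SAGEMAKER_JOB_NAME" in os.environ


if is_training_run_on_sagemaker():
    # SageMaker ships stdout to CloudWatch, so attaching a stdout handler
    # makes the trainer's metric logs visible in the training job logs.
    logging.getLogger("transformers").addHandler(logging.StreamHandler(sys.stdout))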
...
@@ -1053,7 +1053,7 @@ class TrainingArguments:
             if is_torch_tpu_available():
                 xm.rendezvous(desc)
             elif is_sagemaker_dp_enabled():
-                sm_dist.Barrier()
+                sm_dist.barrier()
             else:
                 torch.distributed.barrier()
             yield
...
@@ -1064,7 +1064,7 @@ class TrainingArguments:
                 if is_torch_tpu_available():
                     xm.rendezvous(desc)
                 elif is_sagemaker_dp_enabled():
-                    sm_dist.Barrier()
+                    sm_dist.barrier()
                 else:
                     torch.distributed.barrier()
             else:
...
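The one-character fix above (sm_dist.Barrier() to sm_dist.barrier()) matters because SageMaker's data-parallel module mirrors torch.distributed, whose barrier is lowercase, so the capitalized call failed and broke the rendezvous. The surrounding barriers implement a "main process first" pattern: replicas block at the collective until the main process has left the guarded block. A generic sketch of that pattern with plain torch.distributed (an illustration, not the TrainingArguments code itself):

from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def main_process_first(is_main_process: bool):
    # Run the guarded block on the main process while replicas wait.
    # Both barrier calls are the same collective: the replicas enter it
    # first, and it completes once the main process arrives at the end.
    try:
        if not is_main_process:
            dist.barrier()  # replicas wait here for the main process
        yield
    finally:
        if is_main_process:
            dist.barrier()  # main process is done; release the replicas

Typical use is dataset preprocessing or caching that should run exactly once; under SageMaker data parallelism the same two rendezvous points go through sm_dist.barrier() instead, which is what this commit repairs.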
...
@@ -17,8 +17,8 @@ class SageMakerTestEnvironment:
     role = "arn:aws:iam::558105141721:role/sagemaker_execution_role"
     hyperparameters = {
         "task_name": "mnli",
-        "per_device_train_batch_size": 32,
-        "per_device_eval_batch_size": 32,
+        "per_device_train_batch_size": 16,
+        "per_device_eval_batch_size": 16,
         "do_train": True,
         "do_eval": True,
         "do_predict": True,
...
@@ -55,9 +55,9 @@ class SageMakerTestEnvironment:
     @property
     def image_uri(self) -> str:
         if self.framework == "pytorch":
-            return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04"
+            return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04"
         else:
-            return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.4.2-gpu-py37-cu110-ubuntu18.04"
+            return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.6.1-gpu-py37-cu110-ubuntu18.04"

     @pytest.fixture(scope="class")
...
-git+https://github.com/huggingface/transformers.git@master # install master or adjust ist with vX.X.X for installing version specific transforms
\ No newline at end of file
+git+https://github.com/huggingface/transformers.git@master # install master or adjust it with vX.X.X for installing version specific transforms
+datasets==1.8.0
\ No newline at end of file
...
@@ -27,21 +27,21 @@ if is_sagemaker_available():
             "framework": "pytorch",
             "script": "run_glue.py",
             "model_name_or_path": "distilbert-base-cased",
-            "instance_type": "ml.p3dn.24xlarge",
+            "instance_type": "ml.p3.16xlarge",
             "results": {"train_runtime": 650, "eval_accuracy": 0.7, "eval_loss": 0.6},
         },
         {
             "framework": "pytorch",
             "script": "run_ddp.py",
             "model_name_or_path": "distilbert-base-cased",
-            "instance_type": "ml.p3dn.24xlarge",
+            "instance_type": "ml.p3.16xlarge",
             "results": {"train_runtime": 600, "eval_accuracy": 0.7, "eval_loss": 0.6},
         },
         {
             "framework": "tensorflow",
             "script": "run_tf_dist.py",
             "model_name_or_path": "distilbert-base-cased",
-            "instance_type": "ml.p3dn.24xlarge",
+            "instance_type": "ml.p3.16xlarge",
             "results": {"train_runtime": 600, "eval_accuracy": 0.6, "eval_loss": 0.7},
         },
     ]
...
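For reference, a hypothetical way to launch one of the PyTorch test configurations above with the SageMaker Python SDK's HuggingFace estimator. The role ARN and source_dir are placeholders, and enabling smdistributed data parallelism for run_glue.py is an assumption based on the commit's DDP focus, not something shown in this diff:

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="run_glue.py",
    source_dir="./tests/sagemaker/scripts/pytorch",  # placeholder path
    role="arn:aws:iam::<account>:role/sagemaker_execution_role",  # placeholder
    instance_type="ml.p3.16xlarge",
    instance_count=2,
    transformers_version="4.6.1",  # matches the updated DLC image tag
    pytorch_version="1.7.1",
    py_version="py36",
    hyperparameters={
        "model_name_or_path": "distilbert-base-cased",
        "task_name": "mnli",
        "per_device_train_batch_size": 16,
    },
    # SageMaker's smdistributed library provides the data-parallel backend
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
estimator.fit()

Note that the tests switch from ml.p3dn.24xlarge to ml.p3.16xlarge (8x V100); both are among the instance types SageMaker's data-parallel library supports.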