Unverified commit f689743e, authored by Philipp Schmid and committed by GitHub
Browse files

SageMaker: Fix sagemaker DDP & metric logs (#13181)



* Barrier -> barrier

* added logger for metrics

* removed stream handler in trainer

* moved handler

* removed streamhandler from trainer

* updated test image and instance type; added datasets version to test

* Update tests/sagemaker/scripts/pytorch/requirements.txt
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 8679bd71
......@@ -26,7 +26,6 @@ import shutil
import sys
import time
import warnings
from logging import StreamHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -68,7 +67,6 @@ from .file_utils import (
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
is_training_run_on_sagemaker,
)
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
......@@ -173,9 +171,6 @@ if is_sagemaker_mp_enabled():
from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
if TYPE_CHECKING:
import optuna
......
......@@ -20,9 +20,11 @@ import datetime
import json
import math
import os
import sys
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from logging import StreamHandler
from typing import Dict, Iterator, List, Optional, Union
import numpy as np
......@@ -32,7 +34,12 @@ from torch import nn
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .file_utils import (
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
is_training_run_on_sagemaker,
)
from .tokenization_utils_base import BatchEncoding
from .utils import logging
......@@ -42,6 +49,8 @@ if is_sagemaker_dp_enabled():
else:
import torch.distributed as dist
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
......
......@@ -1053,7 +1053,7 @@ class TrainingArguments:
if is_torch_tpu_available():
xm.rendezvous(desc)
elif is_sagemaker_dp_enabled():
sm_dist.Barrier()
sm_dist.barrier()
else:
torch.distributed.barrier()
yield
......@@ -1064,7 +1064,7 @@ class TrainingArguments:
if is_torch_tpu_available():
xm.rendezvous(desc)
elif is_sagemaker_dp_enabled():
sm_dist.Barrier()
sm_dist.barrier()
else:
torch.distributed.barrier()
else:
......
......@@ -17,8 +17,8 @@ class SageMakerTestEnvironment:
role = "arn:aws:iam::558105141721:role/sagemaker_execution_role"
hyperparameters = {
"task_name": "mnli",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"per_device_train_batch_size": 16,
"per_device_eval_batch_size": 16,
"do_train": True,
"do_eval": True,
"do_predict": True,
......@@ -55,9 +55,9 @@ class SageMakerTestEnvironment:
@property
def image_uri(self) -> str:
if self.framework == "pytorch":
return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04"
return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04"
else:
return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.4.2-gpu-py37-cu110-ubuntu18.04"
return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.6.1-gpu-py37-cu110-ubuntu18.04"
@pytest.fixture(scope="class")
......
git+https://github.com/huggingface/transformers.git@master # install master or adjust ist with vX.X.X for installing version specific transforms
\ No newline at end of file
git+https://github.com/huggingface/transformers.git@master # install master or adjust it with vX.X.X for installing version specific transforms
datasets==1.8.0
\ No newline at end of file
......@@ -27,21 +27,21 @@ if is_sagemaker_available():
"framework": "pytorch",
"script": "run_glue.py",
"model_name_or_path": "distilbert-base-cased",
"instance_type": "ml.p3dn.24xlarge",
"instance_type": "ml.p3.16xlarge",
"results": {"train_runtime": 650, "eval_accuracy": 0.7, "eval_loss": 0.6},
},
{
"framework": "pytorch",
"script": "run_ddp.py",
"model_name_or_path": "distilbert-base-cased",
"instance_type": "ml.p3dn.24xlarge",
"instance_type": "ml.p3.16xlarge",
"results": {"train_runtime": 600, "eval_accuracy": 0.7, "eval_loss": 0.6},
},
{
"framework": "tensorflow",
"script": "run_tf_dist.py",
"model_name_or_path": "distilbert-base-cased",
"instance_type": "ml.p3dn.24xlarge",
"instance_type": "ml.p3.16xlarge",
"results": {"train_runtime": 600, "eval_accuracy": 0.6, "eval_loss": 0.7},
},
]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment