Unverified Commit 10fab90f authored by Michael Benayoun, committed by GitHub


`torch.distributed` group initialization for `torch_neuron` disabled when `optimum-neuron` is installed (#22728)

* Make the process group initialization not happen if optimum_neuron is installed

* Add warning

* Remove list and add warning
parent 1306b7d3
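
For readers skimming the diff below: the change gates the module-level XLA process-group setup on whether the `optimum.neuron` package can be found. A minimal, self-contained sketch of that pattern (the helper name `_neuron_available` and the branch bodies are illustrative stand-ins, not the repository's exact code):

```python
import importlib.util
import os


def _neuron_available() -> bool:
    # Illustrative stand-in for is_optimum_neuron_available(): check whether
    # the optimum.neuron submodule can be located without importing it.
    try:
        return importlib.util.find_spec("optimum.neuron") is not None
    except ModuleNotFoundError:
        # find_spec("optimum.neuron") imports the parent "optimum" package first,
        # so this is raised when optimum itself is not installed.
        return False


if os.environ.get("TORCHELASTIC_RUN_ID"):  # set when launched via torchrun
    if _neuron_available():
        # optimum-neuron's TrainiumTrainer is expected to own distributed setup.
        print("Deferring process group initialization to optimum[neuron].")
    else:
        # Otherwise the library sets up the XLA backend itself
        # (see the training_args.py hunk below).
        print("Would call torch.distributed.init_process_group(backend='xla').")
```
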
src/transformers/training_args.py
@@ -54,8 +54,13 @@ from .utils import (
     logging,
     requires_backends,
 )
+from .utils.import_utils import is_optimum_neuron_available
 
 
+logger = logging.get_logger(__name__)
+log_levels = logging.get_log_levels_dict().copy()
+trainer_log_levels = dict(**log_levels, passive=-1)
+
 if is_torch_available():
     import torch
     import torch.distributed as dist
@@ -67,12 +72,23 @@ if is_torch_neuroncore_available(check_device=False):
     # torchrun support
     # https://github.com/pytorch/xla/pull/3609
     if os.environ.get("TORCHELASTIC_RUN_ID"):
-        import torch_xla.distributed.xla_backend as xbn
+        if is_optimum_neuron_available():
+            logger.info(
+                "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
+                "will fail otherwise."
+            )
+        else:
+            logger.warning(
+                "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
+                "training on AWS Trainium instances. More information here: "
+                "https://github.com/huggingface/optimum-neuron"
+            )
+            import torch_xla.distributed.xla_backend as xbn
 
-        if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-            torch.distributed.init_process_group(backend="xla")
-            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-                raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
+            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+                torch.distributed.init_process_group(backend="xla")
+                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+                    raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
 
 
 if is_sagemaker_mp_enabled():
@@ -81,11 +97,6 @@ if is_sagemaker_mp_enabled():
     smp.init()
 
 
-logger = logging.get_logger(__name__)
-log_levels = logging.get_log_levels_dict().copy()
-trainer_log_levels = dict(**log_levels, passive=-1)
-
-
 def default_logdir() -> str:
     """
     Same default as PyTorch
...
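
One practical consequence of the hunk above: with optimum-neuron installed, importing Transformers under torchrun no longer creates the XLA process group, so whatever drives training (normally optimum-neuron's TrainiumTrainer) has to set it up. A hedged sketch of what that looks like from a caller's side, assuming torch_xla is installed and the script actually needs the default process group:

```python
import os

import torch.distributed as dist
import torch_xla.distributed.xla_backend  # noqa: F401  registers the "xla" backend

if os.environ.get("TORCHELASTIC_RUN_ID") and not dist.is_initialized():
    # Rank and world size come from the environment variables torchrun sets
    # (the default env:// init method reads RANK and WORLD_SIZE).
    dist.init_process_group(backend="xla")
```
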
src/transformers/utils/import_utils.py
@@ -583,6 +583,10 @@ def is_optimum_available():
     return importlib.util.find_spec("optimum") is not None
 
 
+def is_optimum_neuron_available():
+    return importlib.util.find_spec("optimum.neuron") is not None
+
+
 def is_safetensors_available():
     if is_torch_available():
         if version.parse(_torch_version) >= version.parse("1.10"):
...
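
The new helper mirrors the existing `is_optimum_available()` check. Assuming it is exposed from `transformers.utils.import_utils` (as the import added in the first hunk suggests), calling code can use it like this:

```python
from transformers.utils.import_utils import is_optimum_neuron_available

if is_optimum_neuron_available():
    print("optimum[neuron] found: its TrainiumTrainer should drive XLA setup.")
else:
    print("optimum[neuron] not found: Transformers keeps its own XLA initialization.")
```

Note that `importlib.util.find_spec` on a dotted name imports the parent package first, so in an environment without `optimum` at all a defensive caller may want to wrap the check in a try/except, as in the sketch near the top of this page.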