Unverified Commit 75769744 authored by huismiling, committed by GitHub

add Cambricon MLUs support (#29627)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker
parent 0efcf323
@@ -1109,6 +1109,7 @@ _import_structure = {
         "is_timm_available",
         "is_tokenizers_available",
         "is_torch_available",
+        "is_torch_mlu_available",
         "is_torch_neuroncore_available",
         "is_torch_npu_available",
         "is_torch_tpu_available",
@@ -5987,6 +5988,7 @@ if TYPE_CHECKING:
         is_timm_available,
         is_tokenizers_available,
         is_torch_available,
+        is_torch_mlu_available,
         is_torch_neuroncore_available,
         is_torch_npu_available,
         is_torch_tpu_available,
...
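These export-list changes make the new availability check part of the package's public API, next to the existing `is_torch_available` / `is_torch_npu_available` helpers. A minimal sketch of what that enables (assuming a transformers build that includes this commit):

```python
# With this commit applied, the MLU check is importable from the package root.
from transformers import is_torch_available, is_torch_mlu_available

print(is_torch_available())      # True whenever PyTorch is installed
print(is_torch_mlu_available())  # True only when torch_mlu is installed too
```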
@@ -21,7 +21,7 @@ import weakref
 from functools import partialmethod

 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_available, logging

 if is_torch_available():
@@ -38,6 +38,9 @@ def is_deepspeed_available():
     # AND checking it has an author field in the metadata that is HuggingFace.
     if package_exists:
         try:
+            if is_torch_mlu_available():
+                _ = importlib_metadata.metadata("deepspeed-mlu")
+                return True
             _ = importlib_metadata.metadata("deepspeed")
             return True
         except importlib_metadata.PackageNotFoundError:
...
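For context, `is_deepspeed_available` above detects the Cambricon fork of DeepSpeed (`deepspeed-mlu`) by probing installed-package metadata instead of importing the module, which is cheap and avoids import side effects. A standalone sketch of the same pattern, using the stdlib `importlib.metadata` in place of the vendored `importlib_metadata` (the `on_mlu` flag stands in for `is_torch_mlu_available()`):

```python
import importlib.metadata as importlib_metadata

def deepspeed_installed(on_mlu: bool) -> bool:
    """Detect a DeepSpeed install via package metadata, preferring the MLU fork."""
    try:
        if on_mlu:
            importlib_metadata.metadata("deepspeed-mlu")  # raises if absent
            return True
        importlib_metadata.metadata("deepspeed")
        return True
    except importlib_metadata.PackageNotFoundError:
        return False
```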
@@ -41,6 +41,7 @@ from ..utils import (
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
+    is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
@@ -851,6 +852,8 @@ class Pipeline(_ScikitCompat):
                 self.device = torch.device(device)
             elif device < 0:
                 self.device = torch.device("cpu")
+            elif is_torch_mlu_available():
+                self.device = torch.device(f"mlu:{device}")
             elif is_torch_cuda_available():
                 self.device = torch.device(f"cuda:{device}")
             elif is_torch_npu_available():
@@ -995,6 +998,9 @@ class Pipeline(_ScikitCompat):
         if self.device.type == "cuda":
             with torch.cuda.device(self.device):
                 yield
+        elif self.device.type == "mlu":
+            with torch.mlu.device(self.device):
+                yield
         else:
             yield
...
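With these two branches in place, an integer `device` argument to a pipeline resolves to an MLU before CUDA is even considered, and pipeline calls run inside a `torch.mlu.device` context so tensors land on the right card. A hypothetical usage sketch (the model name is illustrative; MLU is only selected on a host with `torch_mlu` installed):

```python
from transformers import pipeline

# On an MLU host, device=0 now resolves to torch.device("mlu:0");
# elsewhere the existing CUDA/NPU/XPU fallbacks apply unchanged.
classifier = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative
    device=0,
)
print(classifier.device)  # e.g. device(type='mlu', index=0)
```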
@@ -151,6 +151,7 @@ from .utils import (
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
+    is_torch_mlu_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_xla_available,
@@ -2671,6 +2672,17 @@ class Trainer:
                         f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
                         "\nThis won't yield the same results as if the training had not been interrupted."
                     )
+        if is_torch_mlu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
+            else:
+                try:
+                    torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )

     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2745,6 +2757,12 @@ class Trainer:
             else:
                 rng_states["npu"] = torch.npu.random.get_rng_state()

+        if is_torch_mlu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
+            else:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state()
+
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
...
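The two hunks are a matched pair: `_save_checkpoint` snapshots the MLU RNG state into `rng_states["mlu"]`, and the resume path restores it, mirroring what Trainer already does for CUDA and NPU. A device-agnostic sketch of that save/restore pattern (illustrative helper names; `backend` stands in for `torch.cuda`, `torch.npu`, or `torch.mlu`):

```python
def snapshot_rng(backend, distributed: bool):
    """Capture RNG state for one accelerator backend at checkpoint time."""
    if distributed:
        return backend.random.get_rng_state_all()  # one state tensor per device
    return backend.random.get_rng_state()

def restore_rng(backend, state, distributed: bool) -> None:
    """Put the snapshot back on resume so dropout masks etc. replay identically."""
    if distributed:
        backend.random.set_rng_state_all(state)
    else:
        backend.random.set_rng_state(state)
```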
@@ -35,6 +35,7 @@ from .utils import (
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
+    is_torch_mlu_available,
     is_torch_mps_available,
     is_torch_npu_available,
     is_torch_xla_available,
@@ -100,6 +101,8 @@ def set_seed(seed: int, deterministic: bool = False):
         # ^^ safe to call this function even if cuda is not available
         if deterministic:
             torch.use_deterministic_algorithms(True)
+    if is_torch_mlu_available():
+        torch.mlu.manual_seed_all(seed)
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
     if is_torch_xpu_available():
@@ -455,7 +458,7 @@ class TrainerMemoryTracker:
         import psutil  # noqa

-        if is_torch_cuda_available():
+        if is_torch_cuda_available() or is_torch_mlu_available():
             import torch

             self.torch = torch
@@ -528,6 +531,9 @@ class TrainerMemoryTracker:
         if torch.cuda.is_available():
             self.torch.cuda.reset_peak_memory_stats()
             self.torch.cuda.empty_cache()
+        elif is_torch_mlu_available():
+            self.torch.mlu.reset_peak_memory_stats()
+            self.torch.mlu.empty_cache()
         elif is_torch_xpu_available():
             self.torch.xpu.reset_peak_memory_stats()
             self.torch.xpu.empty_cache()
@@ -541,6 +547,8 @@ class TrainerMemoryTracker:
         if self.torch is not None:
             if torch.cuda.is_available():
                 self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            elif is_torch_mlu_available():
+                self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
             elif is_torch_xpu_available():
                 self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
             elif is_torch_npu_available():
@@ -572,6 +580,8 @@ class TrainerMemoryTracker:
         if self.torch is not None:
             if torch.cuda.is_available():
                 self.torch.cuda.empty_cache()
+            elif is_torch_mlu_available():
+                self.torch.mlu.empty_cache()
             elif is_torch_xpu_available():
                 self.torch.xpu.empty_cache()
             elif is_torch_npu_available():
@@ -589,6 +599,9 @@ class TrainerMemoryTracker:
             if torch.cuda.is_available():
                 self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
                 self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            elif is_torch_mlu_available():
+                self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
             elif is_torch_xpu_available():
                 self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
                 self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
...
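`TrainerMemoryTracker` follows one recipe on every backend: reset peak stats and record `memory_allocated()` when a stage starts, then read `memory_allocated()` and `max_memory_allocated()` when it ends and report the deltas; this change simply adds a `torch.mlu` branch to that chain. A condensed sketch of the recipe (illustrative class; the real tracker also handles CPU via psutil):

```python
import torch

class PeakMemoryProbe:
    """Minimal per-stage accelerator memory probe; pass torch.mlu on MLU hosts."""

    def __init__(self, backend=torch.cuda):
        self.backend = backend

    def start(self) -> None:
        self.backend.reset_peak_memory_stats()
        self.backend.empty_cache()
        self.mem_at_start = self.backend.memory_allocated()

    def stop(self) -> dict:
        self.backend.empty_cache()
        mem_now = self.backend.memory_allocated()
        mem_peak = self.backend.max_memory_allocated()
        # allocated delta = what the stage kept; peaked delta = transient overhead
        return {
            "alloc_delta": mem_now - self.mem_at_start,
            "peaked_delta": max(0, mem_peak - mem_now),
        }
```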
@@ -46,6 +46,7 @@ from .utils import (
     is_torch_available,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
+    is_torch_mlu_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
@@ -993,7 +994,7 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "The backend to be used for distributed training",
-            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl"],
+            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
         },
     )
     tpu_num_cores: Optional[int] = field(
@@ -1549,6 +1550,7 @@ class TrainingArguments:
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
+            and (self.device.type != "mlu")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
@@ -1556,13 +1558,14 @@ class TrainingArguments:
         ):
             raise ValueError(
                 "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
+                " (`--fp16_full_eval`) can only be used on CUDA, MLU, or NPU devices, or certain XPU devices (with IPEX)."
             )
         if (
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
+            and (self.device.type != "mlu")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
@@ -1572,7 +1575,7 @@ class TrainingArguments:
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
+                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU, or CPU/TPU/NeuronCore devices."
             )
         if self.torchdynamo is not None:
@@ -1999,6 +2002,10 @@ class TrainingArguments:
             device = torch.device("xpu:0")
             torch.xpu.set_device(device)
             self._n_gpu = 1
+        elif is_torch_mlu_available():
+            device = torch.device("mlu:0")
+            torch.mlu.set_device(device)
+            self._n_gpu = 1
         elif is_torch_npu_available():
             device = torch.device("npu:0")
             torch.npu.set_device(device)
...
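Taken together, the `training_args.py` changes let the device resolver pick `mlu:0` automatically, lift the fp16/bf16 device checks for MLU, and register Cambricon's CNCL as a distributed-backend choice. A hypothetical configuration (the output path is illustrative; the device shown assumes a host with `torch_mlu`):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./mlu-run",   # illustrative path
    fp16=True,                # now allowed when args.device.type == "mlu"
    ddp_backend="cncl",       # Cambricon's collective library, newly accepted
)
print(args.device)            # -> mlu:0 on an MLU host
```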
@@ -185,6 +185,7 @@ from .import_utils import (
     is_torch_fp16_available_on_device,
     is_torch_fx_available,
     is_torch_fx_proxy,
+    is_torch_mlu_available,
     is_torch_mps_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
...
@@ -586,6 +586,29 @@ def is_torch_npu_available(check_device=False):
     return hasattr(torch, "npu") and torch.npu.is_available()


+@lru_cache()
+def is_torch_mlu_available(check_device=False):
+    "Checks if `torch_mlu` is installed and potentially if an MLU is in the environment"
+    if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
+        return False
+
+    import torch
+    import torch_mlu  # noqa: F401
+
+    from ..dependency_versions_table import deps
+
+    deps["deepspeed"] = "deepspeed-mlu>=0.10.1"
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no MLU is found
+            _ = torch.mlu.device_count()
+            return torch.mlu.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "mlu") and torch.mlu.is_available()
+
+
 def is_torchdynamo_available():
     if not is_torch_available():
         return False
...
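Note that `is_torch_mlu_available` is cached with `@lru_cache()` and, as a side effect, repoints the pinned `deepspeed` dependency at `deepspeed-mlu>=0.10.1`; passing `check_device=True` additionally verifies that at least one card actually responds. A hedged usage sketch (the tensor op is illustrative):

```python
import torch
from transformers.utils import is_torch_mlu_available

if is_torch_mlu_available(check_device=True):
    # torch_mlu is importable and at least one MLU answered device_count().
    x = torch.ones(2, 2, device="mlu")
    print(x.device)
else:
    print("No Cambricon MLU detected; use CPU/CUDA instead.")
```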