Unverified Commit 537deb78 authored by Hengwen Tong, committed by GitHub

Remove redundant backend checks in training_args.py (#30999)



* Remove backend checks in training_args.py

* Explicitly initialize the device

---------
Co-authored-by: tonghengwen <tonghengwen@cambricon.com>
parent dd4654ea
@@ -67,7 +67,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist
 
-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1677,38 +1677,9 @@ class TrainingArguments:
             )
             self.accelerator_config.split_batches = self.split_batches
 
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (self.fp16 or self.fp16_full_eval)
-        ):
-            raise ValueError(
-                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
-            )
-
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (get_xla_device_type(self.device) != "TPU")
-            and (self.device.type != "cpu")
-            and (self.bf16 or self.bf16_full_eval)
-        ):
-            raise ValueError(
-                "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
-            )
+        # Initialize device before we proceed
+        if self.framework == "pt" and is_torch_available():
+            self.device
 
         if self.torchdynamo is not None:
             warnings.warn(
...
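For context, the bare `self.device` access added above is not a no-op: in `TrainingArguments`, `device` is a property backed by a cached setup helper, so the first read performs device and distributed-state initialization, after which later checks can rely on an accurate `device.type`. The snippet below is a minimal sketch of that cached-property pattern, with simplified names; it is not the actual transformers implementation.

from functools import cached_property

import torch


class TrainingArgumentsSketch:
    """Illustrates why merely referencing `self.device` initializes the backend."""

    @cached_property
    def _setup_devices(self) -> torch.device:
        # Runs once; later accesses reuse the cached result.
        print("initializing device...")
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def device(self) -> torch.device:
        # Reading the property triggers `_setup_devices` on first access.
        return self._setup_devices


args = TrainingArgumentsSketch()
args.device  # first access performs the initialization
args.device  # cached; no re-initialization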