Unverified Commit 5c57463b authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Enable fp16 on CPU (#30459)

* Check removing flag for torch

* LLM oops

* Getting there...

* More discoveries

* Change

* Clean up and prettify

* Logic check

* Not
parent d1d94d79
...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) ...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
......
...@@ -69,7 +69,11 @@ from .models.auto.modeling_auto import ( ...@@ -69,7 +69,11 @@ from .models.auto.modeling_auto import (
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from .optimization import Adafactor, get_scheduler from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from .pytorch_utils import (
ALL_LAYERNORM_LAYERS,
is_torch_greater_or_equal_than_1_13,
is_torch_greater_or_equal_than_2_3,
)
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import ( from .trainer_callback import (
CallbackHandler, CallbackHandler,
...@@ -620,7 +624,8 @@ class Trainer: ...@@ -620,7 +624,8 @@ class Trainer:
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto": if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
if args.device == torch.device("cpu"): if args.device == torch.device("cpu"):
if args.fp16: if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu") if not is_torch_greater_or_equal_than_2_3:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else: else:
args.half_precision_backend = "cpu_amp" args.half_precision_backend = "cpu_amp"
logger.info(f"Using {args.half_precision_backend} half precision backend") logger.info(f"Using {args.half_precision_backend} half precision backend")
......
...@@ -67,7 +67,7 @@ if is_torch_available(): ...@@ -67,7 +67,7 @@ if is_torch_available():
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from .pytorch_utils import is_torch_greater_or_equal_than_2_0 from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
if is_accelerate_available(): if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState from accelerate.state import AcceleratorState, PartialState
...@@ -1618,6 +1618,7 @@ class TrainingArguments: ...@@ -1618,6 +1618,7 @@ class TrainingArguments:
if ( if (
self.framework == "pt" self.framework == "pt"
and is_torch_available() and is_torch_available()
and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
and (self.device.type != "cuda") and (self.device.type != "cuda")
and (self.device.type != "mlu") and (self.device.type != "mlu")
and (self.device.type != "npu") and (self.device.type != "npu")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.