Unverified commit 65ea2ddf, authored by Aaron Pham, committed by GitHub
Browse files

feat(config): support parsing torch.dtype (#1641)


Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
parent b514d3c4
from typing import Optional
from typing import Optional, Union
import torch
from transformers import PretrainedConfig
......@@ -58,7 +58,7 @@ class ModelConfig:
trust_remote_code: bool,
download_dir: Optional[str],
load_format: str,
dtype: str,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
......@@ -331,7 +331,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: str,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
......@@ -339,10 +339,12 @@ def _get_and_verify_dtype(
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
......@@ -350,6 +352,10 @@ def _get_and_verify_dtype(
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Verify the dtype.
if torch_dtype != config_dtype:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment