"docs/source/en/generation_strategies.md" did not exist on "3f96e0b4e483c4c7d4ec9dcdc24b0b0cdf31ea5c"
Unverified Commit 20081c74 authored by Michael Goin, committed by GitHub

Update `dtype_byte_size` to handle torch.float8_e4m3fn/float8_e5m2 types (#30488)

* Update modeling_utils/dtype_byte_size to handle float8 types

* Add a test for dtype_byte_size

* Format

* Fix bool
parent 59e715f7
@@ -324,7 +324,7 @@ def dtype_byte_size(dtype):
     """
     if dtype == torch.bool:
         return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
     if bit_search is None:
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
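The anchor change matters because `str(torch.float8_e4m3fn)` is `"torch.float8_e4m3fn"`: the bit count is followed by the `_e4m3fn` suffix, so the old end-of-string anchor never matches and `dtype_byte_size` raises. Below is a minimal standalone sketch of just the regex lookup, assuming PyTorch >= 2.1 for the float8 dtypes (the `_bit_size` helper and the pattern names are invented for the demo, not part of the patch):

```python
import re

import torch  # float8 dtypes require PyTorch >= 2.1

OLD_PATTERN = r"[^\d](\d+)$"   # bit count must sit at the end of the dtype name
NEW_PATTERN = r"[^\d](\d+)_?"  # also accepts a trailing suffix such as "_e4m3fn"


def _bit_size(dtype, pattern):
    """Illustrative copy of the regex lookup inside dtype_byte_size."""
    match = re.search(pattern, str(dtype))
    return None if match is None else int(match.groups()[0])


# "torch.float8_e4m3fn" does not end in digits, so the old anchor finds nothing,
# which is the case that previously made dtype_byte_size raise a ValueError.
print(_bit_size(torch.float8_e4m3fn, OLD_PATTERN))  # None
print(_bit_size(torch.float8_e4m3fn, NEW_PATTERN))  # 8 -> 8 / 8 = 1 byte per element

# Existing dtypes resolve to the same bit count with either pattern.
print(_bit_size(torch.bfloat16, OLD_PATTERN), _bit_size(torch.bfloat16, NEW_PATTERN))  # 16 16
```

With the relaxed pattern, `torch.float8_e4m3fn` and `torch.float8_e5m2` both resolve to 8 bits, i.e. 1 byte per element, while the previously supported dtypes keep their existing sizes.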
@@ -101,7 +101,12 @@ if is_torch_available():
         _prepare_4d_attention_mask,
         _prepare_4d_causal_attention_mask,
     )
-    from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
+    from transformers.modeling_utils import (
+        _find_disjoint,
+        _find_identical,
+        dtype_byte_size,
+        shard_checkpoint,
+    )
 
     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
@@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus):
                     module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
                 )
 
+    def test_torch_dtype_byte_sizes(self):
+        torch_dtypes_and_bytes = [
+            (torch.double, 8),
+            (torch.float64, 8),
+            (torch.float, 4),
+            (torch.float32, 4),
+            (torch.half, 2),
+            (torch.float16, 2),
+            (torch.bfloat16, 2),
+            (torch.long, 8),
+            (torch.int64, 8),
+            (torch.int, 4),
+            (torch.int32, 4),
+            (torch.short, 2),
+            (torch.int16, 2),
+            (torch.uint8, 1),
+            (torch.int8, 1),
+            (torch.float8_e4m3fn, 1),
+            (torch.float8_e5m2, 1),
+            (torch.bool, 0.125),
+        ]
+
+        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
+            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
+
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
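The new test pins down the expected byte sizes; in practice `dtype_byte_size` feeds the per-weight size arithmetic used when sharding checkpoints, so float8 parameters now count as one byte per element instead of raising. A quick usage sketch, assuming a `transformers` install that includes this commit and PyTorch >= 2.1 (the tensor shape is arbitrary):

```python
import torch

from transformers.modeling_utils import dtype_byte_size

weight = torch.empty(4096, 4096, dtype=torch.float8_e4m3fn)

# 8-bit floats take one byte per element, so a 4096 x 4096 weight is ~16 MB.
print(dtype_byte_size(weight.dtype))                        # 1.0
print(int(weight.numel() * dtype_byte_size(weight.dtype)))  # 16777216
```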