"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "842e99f1b9ee2a0fa239997ef695c5ed0bd77195"
Unverified Commit af0e4b7b authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Fix float8_e4m3fn in modeling_utils (#32193)

* Fix float8_e4m3fn in modeling_utils

* style

* fix

* comment
parent 1392a686
...@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model( ...@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
for old_key, new_key in zip(old_keys, new_keys): for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for param_name, param in state_dict.items(): for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys. # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys: if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
...@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model( ...@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
module_name = param_name module_name = param_name
set_module_kwargs = {} set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them. # in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn: is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if ( if (
keep_in_fp32_modules is not None keep_in_fp32_modules is not None
and any( and any(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment