Unverified Commit dec18c86 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax] dont warn for bf16 weights (#923)

dont warn for bf16 weights
parent 25dfd0f8
......@@ -482,29 +482,6 @@ class FlaxModelMixin:
" training."
)
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
# raise a warning if any of the parameters are not in jnp.float32
if len(fp16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
if len(bf16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
return model, unflatten_dict(state)
def save_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment