Unverified Commit d8269eb4 authored by Sanchit Gandhi, committed by GitHub

[Flax `.from_pretrained`] Raise a warning if model weights are not in float32 (#16762)

* [Flax] Raise a warning if model weights are not in float32

* apply suggestions and a few small changes

* reorder wording for better readability
parent 195fbbb6
@@ -657,6 +657,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
# raise a warning if any of the parameters are not in jnp.float32
if len(fp16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
if len(bf16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# set correct parameters
model.params = unflatten_dict(state)
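
For context, a minimal sketch of how a user could act on the new warning by upcasting the loaded parameters back to float32. The model class and checkpoint name below are illustrative only; `to_fp32` is the method referenced in the warning message, and the `jax.tree_map` variant is an equivalent manual upcast.

```python
import jax
import jax.numpy as jnp
from transformers import FlaxBertModel

# Illustrative model/checkpoint; any Flax checkpoint whose weights were saved in
# float16 or bfloat16 would trigger the warning added in this commit.
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# `to_fp32` casts every parameter leaf to float32 and returns the new param tree.
model.params = model.to_fp32(model.params)

# Equivalent manual upcast with jax.tree_map, for reference:
model.params = jax.tree_map(lambda p: p.astype(jnp.float32), model.params)
```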