Unverified commit e92190c0, authored by Suraj Patil and committed by GitHub

Fix Flax params dtype (#13098)



* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove unnecessary import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1c76a516
@@ -236,6 +236,18 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) or :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
@@ -289,7 +301,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
kernel_size=(self.config.conv_kernel[self.layer_id],),
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
dtype=self.dtype,
)
@@ -311,7 +323,7 @@ class FlaxConvWithWeightNorm(nn.Module):
self.conv = nn.Conv(
features=self.config.hidden_size,
kernel_size=(self.config.num_conv_pos_embeddings,),
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
feature_group_count=self.config.num_conv_pos_embedding_groups,
dtype=self.dtype,
@@ -321,7 +333,7 @@ class FlaxConvWithWeightNorm(nn.Module):
self.conv.features // self.conv.feature_group_count,
self.conv.kernel_size[0],
)
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
self.prev_padding = self.conv.kernel_size[0] // 2
@@ -407,7 +419,7 @@ class FlaxWav2Vec2FeatureProjection(nn.Module):
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.projection = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
@@ -439,7 +451,7 @@ class FlaxWav2Vec2Attention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -518,7 +530,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
self.intermediate_dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
if isinstance(self.config.hidden_act, str):
@@ -528,7 +540,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
self.output_dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
@@ -704,7 +716,7 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
)
self.weight_proj = nn.Dense(
self.num_groups * self.num_vars,
kernel_init=jax.nn.initializers.normal(1.0, self.dtype),
kernel_init=jax.nn.initializers.normal(1.0),
dtype=self.dtype,
)
@@ -969,7 +981,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
self.dropout = nn.Dropout(rate=self.config.final_dropout)
self.lm_head = nn.Dense(
self.config.vocab_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
@@ -1078,12 +1090,12 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
self.project_q = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.project_hid = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
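The recurring change in the hunks above is dropping `self.dtype` from the `kernel_init`/`embedding_init` initializers. The dtype passed to a JAX initializer is baked into the parameters it creates, so forwarding the computation dtype there would have initialized fp16/bf16 weights instead of fp32 ones. A small illustration (not part of the diff):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# passing a dtype to the initializer determines the dtype of the *parameters* ...
fp16_init = jax.nn.initializers.normal(stddev=0.02, dtype=jnp.float16)
print(fp16_init(key, (4, 4)).dtype)  # float16

# ... while the default keeps parameters in float32; the module's `dtype`
# then only affects the computation, which is the intended behaviour here
fp32_init = jax.nn.initializers.normal(stddev=0.02)
print(fp32_init(key, (4, 4)).dtype)  # float32
```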
@@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) or :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
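The docstring above also points to the new casting helpers on `FlaxPreTrainedModel` (`to_fp16`, `to_bf16`, `to_fp32`), which change the dtype of the parameters themselves rather than the computation. A minimal sketch, using `FlaxBertModel` with a default config purely as a stand-in for the templated model:

```python
import jax
from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig())

# cast the parameters (not the computation) to half precision ...
fp16_params = model.to_fp16(model.params)

# ... and back to full precision
fp32_params = model.to_fp32(fp16_params)

print(set(jax.tree_util.tree_leaves(jax.tree_map(lambda p: p.dtype, fp16_params))))  # {dtype('float16')}
```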
@@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
@@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
@@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
@@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) or :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
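Both `to_fp16` and `to_bf16` also accept an optional boolean mask with the same structure as the parameters, as exercised by the tests further down. A sketch of keeping numerically sensitive `LayerNorm` parameters in fp32 while casting everything else (again using `FlaxBertModel` as a stand-in):

```python
from flax import traverse_util
from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig())

# build a mask: True = cast this parameter, False = leave it in fp32
flat_params = traverse_util.flatten_dict(model.params)
mask = {path: "LayerNorm" not in path for path in flat_params}
mask = traverse_util.unflatten_dict(mask)

bf16_params = model.to_bf16(model.params, mask)
```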
@@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
@@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
@@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
@@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
@@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
@@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
@@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
@@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
@@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
def setup(self):
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):
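The other behavioural change in this template is casting `final_logits_bias` (which stays an fp32 parameter) to the computation dtype before adding it to the logits. Without the cast, JAX's type promotion would silently upcast half-precision logits back to float32. A tiny illustration of the promotion behaviour:

```python
import jax.numpy as jnp

logits = jnp.zeros((2, 4), dtype=jnp.bfloat16)  # half-precision computation
bias = jnp.zeros((1, 4), dtype=jnp.float32)     # fp32 parameter

print((logits + bias).dtype)                       # float32 -> promotion defeats the bf16 computation
print((logits + bias.astype(logits.dtype)).dtype)  # bfloat16 -> stays in the computation dtype
```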
@@ -36,7 +36,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -613,6 +613,141 @@ class FlaxModelTesterMixin:
else:
new_model_without_prefix(input_ids)
def test_default_params_dtype(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# check if all params are still in float32 when dtype of computation is half-precision
model = model_class(config, dtype=jnp.float16)
types = jax.tree_map(lambda x: x.dtype, model.params)
types = flatten_dict(types)
for name, type_ in types.items():
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
def test_to_bf16(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# cast all params to bf16
params = model.to_bf16(model.params)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
# test masking
flat_params = flatten_dict(params)
key = random.choice(list(flat_params.keys())) # choose a random param
mask = {path: path != key for path in flat_params} # don't cast the key
mask = unflatten_dict(mask)
params = model.to_bf16(model.params, mask)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16 except key
for name, type_ in types.items():
if name == key:
self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
else:
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
def test_to_fp16(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# cast all params to fp16
params = model.to_fp16(model.params)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
# test masking
flat_params = flatten_dict(params)
key = random.choice(list(flat_params.keys())) # choose a random param
mask = {path: path != key for path in flat_params} # don't cast the key
mask = unflatten_dict(mask)
params = model.to_fp16(model.params, mask)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16 except key
for name, type_ in types.items():
if name == key:
self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
else:
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
def test_to_fp32(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# cast all params to fp16 and back to fp32
params = model.to_fp16(model.params)
params = model.to_fp32(params)
# test if all params are in fp32
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
# test masking
flat_params = flatten_dict(params)
key = random.choice(list(flat_params.keys())) # choose a random param
mask = {path: path != key for path in flat_params} # don't cast the key
mask = unflatten_dict(mask)
# cast to fp16 and back to fp32 with mask
params = model.to_fp16(model.params)
params = model.to_fp32(params, mask)
# test if all params are in fp32 except key
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
if name == key:
self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
else:
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
def test_save_load_in_fp16(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# convert weights to fp16 and save
params = model.to_fp16(model.params)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=params)
# load the weights again and check if they are still in fp16
model = model_class.from_pretrained(tmpdirname)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
def test_save_load_in_bf16(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# convert weights to bf16 and save
params = model.to_bf16(model.params)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=params)
# load the weights again and check if they are still in bf16
model = model_class.from_pretrained(tmpdirname)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
@require_flax
@is_staging_test