Unverified Commit e92190c0, authored by Suraj Patil, committed by GitHub

Fix Flax params dtype (#13098)



* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove unnecessary import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1c76a516
@@ -236,6 +236,18 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
+        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
+            The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
+            GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given ``dtype``.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see
+            :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
 """
@@ -289,7 +301,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
-            kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
+            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
@@ -311,7 +323,7 @@ class FlaxConvWithWeightNorm(nn.Module):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
-            kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
+            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
@@ -321,7 +333,7 @@ class FlaxConvWithWeightNorm(nn.Module):
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
-        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
+        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        self.prev_padding = self.conv.kernel_size[0] // 2
@@ -407,7 +419,7 @@ class FlaxWav2Vec2FeatureProjection(nn.Module):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
@@ -439,7 +451,7 @@ class FlaxWav2Vec2Attention(nn.Module):
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -518,7 +530,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        if isinstance(self.config.hidden_act, str):
@@ -528,7 +540,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
        self.output_dense = nn.Dense(
            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
@@ -704,7 +716,7 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
        )
        self.weight_proj = nn.Dense(
            self.num_groups * self.num_vars,
-            kernel_init=jax.nn.initializers.normal(1.0, self.dtype),
+            kernel_init=jax.nn.initializers.normal(1.0),
            dtype=self.dtype,
        )
@@ -969,7 +981,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
@@ -1078,12 +1090,12 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
        self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
        self.project_q = nn.Dense(
            self.config.proj_codevector_dim,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.project_hid = nn.Dense(
            self.config.proj_codevector_dim,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
...
@@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
+        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
+            The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
+            GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given ``dtype``.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see
+            :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
 """

{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
    Args:
@@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
-            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
-            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
-            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
@@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
@@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
@@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
+        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
+            The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
+            GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given ``dtype``.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see
+            :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
 """

{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
@@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
+            dtype=self.dtype,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
-            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
+            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
@@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
+            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
@@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
+            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
-            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
+            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
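The ``dtype=self.dtype`` additions above all serve the same purpose: every submodule a Flax module creates must be handed the parent's computation dtype, otherwise that submodule silently falls back to float32. A minimal, generic Flax sketch of the pattern (module and attribute names here are illustrative, not taken from the template):

    import flax.linen as nn
    import jax.numpy as jnp

    class Block(nn.Module):
        hidden_size: int
        dtype: jnp.dtype = jnp.float32  # computation dtype

        def setup(self):
            # forward the computation dtype to every child module;
            # leaving it out would make the child compute in float32
            self.dense = nn.Dense(self.hidden_size, dtype=self.dtype)
            self.layer_norm = nn.LayerNorm(dtype=self.dtype)

        def __call__(self, hidden_states):
            return self.layer_norm(self.dense(hidden_states))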
@@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
-            self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
+            self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.dropout = nn.Dropout(rate=self.pooler_dropout)
        self.out_proj = nn.Dense(
            self.num_classes,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
@@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
@@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
@@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.Module):
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
@@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.Module):
        else:
            lm_logits = self.lm_head(hidden_states)

-        lm_logits += self.final_logits_bias
+        lm_logits += self.final_logits_bias.astype(self.dtype)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
@@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
            else:
                lm_logits = module.lm_head(hidden_states)

-            lm_logits += module.final_logits_bias
+            lm_logits += module.final_logits_bias.astype(self.dtype)
            return lm_logits, outputs

        outputs = self.module.apply(
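The ``.astype(self.dtype)`` in the two hunks above matters because of JAX's type promotion: adding a float32 bias to float16 logits silently promotes the result back to float32, undoing the half-precision computation. A small, self-contained sketch of that behaviour:

    import jax.numpy as jnp

    logits = jnp.ones((2, 4), dtype=jnp.float16)       # half-precision computation
    bias = jnp.zeros((1, 4), dtype=jnp.float32)        # parameter kept in float32

    print((logits + bias).dtype)                       # float32 -- promotion undoes fp16
    print((logits + bias.astype(logits.dtype)).dtype)  # float16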
@@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
    def setup(self):
        self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(
-            self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
+            self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )

    def _get_encoder_module(self):
...
@@ -36,7 +36,7 @@ if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import unfreeze
-    from flax.traverse_util import flatten_dict
+    from flax.traverse_util import flatten_dict, unflatten_dict

    from transformers import (
        FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -613,6 +613,141 @@ class FlaxModelTesterMixin:
            else:
                new_model_without_prefix(input_ids)

+    def test_default_params_dtype(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            # check if all params are still in float32 when dtype of computation is half-precision
+            model = model_class(config, dtype=jnp.float16)
+            types = jax.tree_map(lambda x: x.dtype, model.params)
+            types = flatten_dict(types)
+
+            for name, type_ in types.items():
+                self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
+
+    def test_to_bf16(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            # cast all params to bf16
+            params = model.to_bf16(model.params)
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            # test if all params are in bf16
+            for name, type_ in types.items():
+                self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
+
+            # test masking
+            flat_params = flatten_dict(params)
+            key = random.choice(list(flat_params.keys()))  # choose a random param
+            mask = {path: path != key for path in flat_params}  # don't cast the key
+            mask = unflatten_dict(mask)
+
+            params = model.to_bf16(model.params, mask)
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            # test if all params are in bf16 except key
+            for name, type_ in types.items():
+                if name == key:
+                    self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
+                else:
+                    self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
+
+    def test_to_fp16(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            # cast all params to fp16
+            params = model.to_fp16(model.params)
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            # test if all params are in fp16
+            for name, type_ in types.items():
+                self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
+
+            # test masking
+            flat_params = flatten_dict(params)
+            key = random.choice(list(flat_params.keys()))  # choose a random param
+            mask = {path: path != key for path in flat_params}  # don't cast the key
+            mask = unflatten_dict(mask)
+
+            params = model.to_fp16(model.params, mask)
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            # test if all params are in fp16 except key
+            for name, type_ in types.items():
+                if name == key:
+                    self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
+                else:
+                    self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
+
+    def test_to_fp32(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            # cast all params to fp16 and back to fp32
+            params = model.to_fp16(model.params)
+            params = model.to_fp32(params)
+
+            # test if all params are in fp32
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            for name, type_ in types.items():
+                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
+
+            # test masking
+            flat_params = flatten_dict(params)
+            key = random.choice(list(flat_params.keys()))  # choose a random param
+            mask = {path: path != key for path in flat_params}  # don't cast the key
+            mask = unflatten_dict(mask)
+
+            # cast to fp16 and back to fp32 with mask
+            params = model.to_fp16(model.params)
+            params = model.to_fp32(params, mask)
+
+            # test if all params are in fp32 except key
+            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+            for name, type_ in types.items():
+                if name == key:
+                    self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
+                else:
+                    self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
+
+    def test_save_load_in_fp16(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            # convert weights to fp16 and save
+            params = model.to_fp16(model.params)
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname, params=params)
+
+                # load the weights again and check if they are still in fp16
+                model = model_class.from_pretrained(tmpdirname)
+                types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+                for name, type_ in types.items():
+                    self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
+
+    def test_save_load_in_bf16(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            # convert weights to bf16 and save
+            params = model.to_bf16(model.params)
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname, params=params)
+
+                # load the weights again and check if they are still in bf16
+                model = model_class.from_pretrained(tmpdirname)
+                types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+                for name, type_ in types.items():
+                    self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")

@require_flax
@is_staging_test
...
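The masking pattern exercised by the new to_bf16/to_fp16/to_fp32 tests is also the intended user-facing recipe. A hedged sketch (the checkpoint and the parameter-path check are illustrative; actual path names depend on the model) of keeping layer-norm parameters in full precision while casting everything else to bfloat16:

    import jax.numpy as jnp
    from flax.traverse_util import flatten_dict, unflatten_dict
    from transformers import FlaxBertModel

    model = FlaxBertModel.from_pretrained("bert-base-uncased")

    # True = cast this parameter, False = leave it in float32
    flat_params = flatten_dict(model.params)
    mask = {path: not any("LayerNorm" in name for name in path) for path in flat_params}
    mask = unflatten_dict(mask)

    bf16_params = model.to_bf16(model.params, mask)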