Unverified Commit e92190c0 authored by Suraj Patil, committed by GitHub

Fix Flax params dtype (#13098)



* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove unnecessary import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1c76a516
...@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module): ...@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module):
self.visual_projection = nn.Dense( self.visual_projection = nn.Dense(
self.projection_dim, self.projection_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False, use_bias=False,
) )
self.text_projection = nn.Dense( self.text_projection = nn.Dense(
self.projection_dim, self.projection_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False, use_bias=False,
) )
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, []) self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
......
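The change above, repeated across many modules in this PR, removes the ``dtype`` argument from ``jax.nn.initializers.normal`` so that initializers always produce full-precision (fp32) parameters, while ``dtype=self.dtype`` on the layer itself only controls the precision of the computation. A minimal sketch of that separation with a bare ``flax.linen.Dense`` (illustrative only, not code from the PR):

import jax
import jax.numpy as jnp
import flax.linen as nn

# dtype controls the computation; the (dtype-free) initializer creates
# the parameters, which stay in float32 by default.
dense = nn.Dense(4, dtype=jnp.bfloat16, kernel_init=jax.nn.initializers.normal(0.02))
params = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

print(jax.tree_util.tree_map(lambda p: p.dtype, params))   # kernel and bias are float32
out = dense.apply(params, jnp.ones((1, 8), dtype=jnp.bfloat16))
print(out.dtype)                                           # bfloat16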
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import os import os
from functools import partial from functools import partial
from pickle import UnpicklingError from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union from typing import Any, Dict, Set, Tuple, Union
import flax.linen as nn import flax.linen as nn
import jax import jax
...@@ -154,6 +154,122 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -154,6 +154,122 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
) )
self._params = params self._params = params
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast the floating-point values of a given parameter ``PyTree`` to the given ``dtype``.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_map(conditional_cast, params)
flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)
return unflatten_dict(flat_params)
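The helper only touches floating-point leaves, so integer buffers keep their dtype, and the optional ``mask`` restricts the cast to a subset of the tree. A toy illustration of the same conditional cast outside the class (a hypothetical standalone snippet, not part of the PR):

import jax
import jax.numpy as jnp

def conditional_cast(param, dtype=jnp.bfloat16):
    # only floating-point arrays are cast; integer leaves pass through unchanged
    if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
        param = param.astype(dtype)
    return param

params = {"kernel": jnp.ones((2, 2)), "position_ids": jnp.arange(4)}
casted = jax.tree_util.tree_map(conditional_cast, params)
print({k: v.dtype for k, v in casted.items()})  # kernel -> bfloat16, position_ids -> int32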
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # load model
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32 precision; to cast these to bfloat16 precision
>>> model.params = model.to_bf16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_bf16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in
place.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32; to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> model.params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> model.params = model.to_fp32(model.params)
"""
return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32; to cast these to float16
>>> model.params = model.to_fp16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_fp16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.float16, mask)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
...@@ -184,6 +300,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -184,6 +300,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
case, ``from_pt`` should be set to :obj:`True`. case, ``from_pt`` should be set to :obj:`True`.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and
:meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
model_args (sequence of positional arguments, `optional`): model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method. All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`): config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
......
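With the new ``dtype`` argument and the conversion methods above, half precision involves two independent choices: the computation dtype picked at load time and, optionally, the parameter dtype changed afterwards. A hedged usage sketch against the API added in this PR (downloads ``bert-base-cased``):

import jax.numpy as jnp
from transformers import FlaxBertModel

# computation runs in bfloat16, parameters are still stored in float32
model = FlaxBertModel.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)

# optionally also store the parameters in bfloat16 (returns a new tree)
model.params = model.to_bf16(model.params)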
...@@ -105,6 +105,18 @@ ALBERT_START_DOCSTRING = r""" ...@@ -105,6 +105,18 @@ ALBERT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
ALBERT_INPUTS_DOCSTRING = r""" ALBERT_INPUTS_DOCSTRING = r"""
...@@ -152,19 +164,16 @@ class FlaxAlbertEmbeddings(nn.Module): ...@@ -152,19 +164,16 @@ class FlaxAlbertEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.token_type_embeddings = nn.Embed( self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size, self.config.type_vocab_size,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -199,21 +208,21 @@ class FlaxAlbertSelfAttention(nn.Module): ...@@ -199,21 +208,21 @@ class FlaxAlbertSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -278,13 +287,13 @@ class FlaxAlbertLayer(nn.Module): ...@@ -278,13 +287,13 @@ class FlaxAlbertLayer(nn.Module):
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype) self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense( self.ffn = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense( self.ffn_output = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module): ...@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module):
def setup(self): def setup(self):
self.embedding_hidden_mapping_in = nn.Dense( self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype) self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
...@@ -596,7 +605,7 @@ class FlaxAlbertModule(nn.Module): ...@@ -596,7 +605,7 @@ class FlaxAlbertModule(nn.Module):
if self.add_pooling_layer: if self.add_pooling_layer:
self.pooler = nn.Dense( self.pooler = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
name="pooler", name="pooler",
) )
......
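The recurring edit in the ALBERT module (and the other models below) is dropping the second argument to ``jax.nn.initializers.normal``: the stddev alone is sufficient, and the initializer's output dtype defaults to float32 independently of the module's computation ``dtype``. A quick check of that behaviour (illustrative, not from the PR):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.normal(stddev=0.02)                  # no dtype argument
print(init(jax.random.PRNGKey(0), (3, 3)).dtype)                # float32 by default
print(init(jax.random.PRNGKey(0), (3, 3), jnp.bfloat16).dtype)  # dtype can still be chosen per call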
...@@ -79,6 +79,18 @@ BART_START_DOCSTRING = r""" ...@@ -79,6 +79,18 @@ BART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
BART_INPUTS_DOCSTRING = r""" BART_INPUTS_DOCSTRING = r"""
...@@ -248,7 +260,7 @@ class FlaxBartAttention(nn.Module): ...@@ -248,7 +260,7 @@ class FlaxBartAttention(nn.Module):
self.embed_dim, self.embed_dim,
use_bias=self.bias, use_bias=self.bias,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
...@@ -404,6 +416,7 @@ class FlaxBartEncoderLayer(nn.Module): ...@@ -404,6 +416,7 @@ class FlaxBartEncoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads, num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
...@@ -412,10 +425,10 @@ class FlaxBartEncoderLayer(nn.Module): ...@@ -412,10 +425,10 @@ class FlaxBartEncoderLayer(nn.Module):
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -514,6 +527,7 @@ class FlaxBartDecoderLayer(nn.Module): ...@@ -514,6 +527,7 @@ class FlaxBartDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
causal=True, causal=True,
dtype=self.dtype,
) )
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function] self.activation_fn = ACT2FN[self.config.activation_function]
...@@ -525,15 +539,16 @@ class FlaxBartDecoderLayer(nn.Module): ...@@ -525,15 +539,16 @@ class FlaxBartDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -668,13 +683,13 @@ class FlaxBartClassificationHead(nn.Module): ...@@ -668,13 +683,13 @@ class FlaxBartClassificationHead(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.dropout = nn.Dropout(rate=self.pooler_dropout) self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense( self.out_proj = nn.Dense(
self.num_classes, self.num_classes,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
...@@ -703,8 +718,7 @@ class FlaxBartEncoder(nn.Module): ...@@ -703,8 +718,7 @@ class FlaxBartEncoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
...@@ -713,8 +727,7 @@ class FlaxBartEncoder(nn.Module): ...@@ -713,8 +727,7 @@ class FlaxBartEncoder(nn.Module):
self.embed_positions = nn.Embed( self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset, self.config.max_position_embeddings + self.offset,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
...@@ -776,8 +789,7 @@ class FlaxBartDecoder(nn.Module): ...@@ -776,8 +789,7 @@ class FlaxBartDecoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
...@@ -786,8 +798,7 @@ class FlaxBartDecoder(nn.Module): ...@@ -786,8 +798,7 @@ class FlaxBartDecoder(nn.Module):
self.embed_positions = nn.Embed( self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset, self.config.max_position_embeddings + self.offset,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
...@@ -850,8 +861,7 @@ class FlaxBartModule(nn.Module): ...@@ -850,8 +861,7 @@ class FlaxBartModule(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...@@ -1256,7 +1266,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module): ...@@ -1256,7 +1266,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings, self.model.shared.num_embeddings,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
...@@ -1300,7 +1310,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module): ...@@ -1300,7 +1310,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
else: else:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
...@@ -1416,7 +1426,7 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): ...@@ -1416,7 +1426,7 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
else: else:
lm_logits = module.lm_head(hidden_states) lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs return lm_logits, outputs
outputs = self.module.apply( outputs = self.module.apply(
...@@ -1647,7 +1657,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module): ...@@ -1647,7 +1657,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module):
def setup(self): def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype) self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense( self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
def _get_encoder_module(self): def _get_encoder_module(self):
......
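``final_logits_bias`` is registered directly with ``self.param`` rather than inside a submodule, so it always lives in fp32; the diff therefore casts it with ``.astype(self.dtype)`` right before adding it to the logits. A small sketch of that pattern in an isolated Flax module (a hypothetical example, not the BART code itself):

import jax
import jax.numpy as jnp
import flax.linen as nn

class BiasedHead(nn.Module):
    vocab_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.vocab_size, use_bias=False, dtype=self.dtype)
        # plain parameter: stays in float32 even when self.dtype is half precision
        self.logits_bias = self.param("logits_bias", jax.nn.initializers.zeros, (1, self.vocab_size))

    def __call__(self, hidden_states):
        logits = self.dense(hidden_states)
        # cast the raw parameter to the computation dtype only at use time
        return logits + self.logits_bias.astype(self.dtype)

head = BiasedHead(vocab_size=8, dtype=jnp.bfloat16)
variables = head.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
print(head.apply(variables, jnp.ones((1, 4))).dtype)  # bfloat16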
...@@ -86,6 +86,18 @@ BEIT_START_DOCSTRING = r""" ...@@ -86,6 +86,18 @@ BEIT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
BEIT_INPUTS_DOCSTRING = r""" BEIT_INPUTS_DOCSTRING = r"""
......
...@@ -106,6 +106,31 @@ BERT_START_DOCSTRING = r""" ...@@ -106,6 +106,31 @@ BERT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
BERT_INPUTS_DOCSTRING = r""" BERT_INPUTS_DOCSTRING = r"""
...@@ -153,19 +178,16 @@ class FlaxBertEmbeddings(nn.Module): ...@@ -153,19 +178,16 @@ class FlaxBertEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.token_type_embeddings = nn.Embed( self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size, self.config.type_vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -199,17 +221,17 @@ class FlaxBertSelfAttention(nn.Module): ...@@ -199,17 +221,17 @@ class FlaxBertSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
...@@ -267,7 +289,7 @@ class FlaxBertSelfOutput(nn.Module): ...@@ -267,7 +289,7 @@ class FlaxBertSelfOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -313,7 +335,7 @@ class FlaxBertIntermediate(nn.Module): ...@@ -313,7 +335,7 @@ class FlaxBertIntermediate(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
...@@ -331,7 +353,7 @@ class FlaxBertOutput(nn.Module): ...@@ -331,7 +353,7 @@ class FlaxBertOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -449,7 +471,7 @@ class FlaxBertPooler(nn.Module): ...@@ -449,7 +471,7 @@ class FlaxBertPooler(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -492,7 +514,8 @@ class FlaxBertLMPredictionHead(nn.Module): ...@@ -492,7 +514,8 @@ class FlaxBertLMPredictionHead(nn.Module):
else: else:
hidden_states = self.decoder(hidden_states) hidden_states = self.decoder(hidden_states)
hidden_states += self.bias bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states return hidden_states
......
...@@ -136,6 +136,18 @@ BIG_BIRD_START_DOCSTRING = r""" ...@@ -136,6 +136,18 @@ BIG_BIRD_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
BIG_BIRD_INPUTS_DOCSTRING = r""" BIG_BIRD_INPUTS_DOCSTRING = r"""
...@@ -184,19 +196,16 @@ class FlaxBigBirdEmbeddings(nn.Module): ...@@ -184,19 +196,16 @@ class FlaxBigBirdEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.token_type_embeddings = nn.Embed( self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size, self.config.type_vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -234,17 +243,17 @@ class FlaxBigBirdSelfAttention(nn.Module): ...@@ -234,17 +243,17 @@ class FlaxBigBirdSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
...@@ -305,19 +314,19 @@ class FlaxBigBirdBlockSparseAttention(nn.Module): ...@@ -305,19 +314,19 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
use_bias=self.config.use_bias, use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
use_bias=self.config.use_bias, use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
use_bias=self.config.use_bias, use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
@staticmethod @staticmethod
...@@ -1074,7 +1083,7 @@ class FlaxBigBirdSelfOutput(nn.Module): ...@@ -1074,7 +1083,7 @@ class FlaxBigBirdSelfOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -1131,7 +1140,7 @@ class FlaxBigBirdIntermediate(nn.Module): ...@@ -1131,7 +1140,7 @@ class FlaxBigBirdIntermediate(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
...@@ -1150,7 +1159,7 @@ class FlaxBigBirdOutput(nn.Module): ...@@ -1150,7 +1159,7 @@ class FlaxBigBirdOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -1301,7 +1310,8 @@ class FlaxBigBirdLMPredictionHead(nn.Module): ...@@ -1301,7 +1310,8 @@ class FlaxBigBirdLMPredictionHead(nn.Module):
else: else:
hidden_states = self.decoder(hidden_states) hidden_states = self.decoder(hidden_states)
hidden_states += self.bias bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states return hidden_states
...@@ -1431,7 +1441,7 @@ class FlaxBigBirdModule(nn.Module): ...@@ -1431,7 +1441,7 @@ class FlaxBigBirdModule(nn.Module):
self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype) self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
self.pooler = nn.Dense( self.pooler = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
......
...@@ -60,6 +60,18 @@ CLIP_START_DOCSTRING = r""" ...@@ -60,6 +60,18 @@ CLIP_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
CLIP_TEXT_INPUTS_DOCSTRING = r""" CLIP_TEXT_INPUTS_DOCSTRING = r"""
...@@ -262,18 +274,10 @@ class FlaxCLIPAttention(nn.Module): ...@@ -262,18 +274,10 @@ class FlaxCLIPAttention(nn.Module):
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim ** -0.5
self.dropout = self.config.attention_dropout self.dropout = self.config.attention_dropout
self.k_proj = nn.Dense( self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
) self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.v_proj = nn.Dense( self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.q_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.out_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.causal = isinstance(self.config, CLIPTextConfig) self.causal = isinstance(self.config, CLIPTextConfig)
if self.causal: if self.causal:
...@@ -354,11 +358,9 @@ class FlaxCLIPMLP(nn.Module): ...@@ -354,11 +358,9 @@ class FlaxCLIPMLP(nn.Module):
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(0.01),
)
self.fc2 = nn.Dense(
self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
) )
self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
def __call__(self, hidden_states): def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states) hidden_states = self.fc1(hidden_states)
...@@ -1032,18 +1034,18 @@ class FlaxCLIPModule(nn.Module): ...@@ -1032,18 +1034,18 @@ class FlaxCLIPModule(nn.Module):
self.visual_projection = nn.Dense( self.visual_projection = nn.Dense(
self.projection_dim, self.projection_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False, use_bias=False,
) )
self.text_projection = nn.Dense( self.text_projection = nn.Dense(
self.projection_dim, self.projection_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False, use_bias=False,
) )
self.logit_scale = self.param( self.logit_scale = self.param(
"logit_scale", lambda _, shape: jnp.ones(shape, dtype=self.dtype) * self.config.logit_scale_init_value, [] "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
) )
def __call__( def __call__(
......
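The ``logit_scale`` initializer likewise no longer hard-codes ``dtype=self.dtype``: the lambda returns a default-precision (fp32) array, and the compute dtype only matters where the scale is used. A minimal illustration of such a custom scalar parameter (hypothetical module and names, not the CLIP code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ScaledLogits(nn.Module):
    logit_scale_init_value: float = 2.6592
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, sims):
        # custom initializer: ignore the RNG key and fill with a constant in fp32
        logit_scale = self.param(
            "logit_scale", lambda _, shape: jnp.ones(shape) * self.logit_scale_init_value, []
        )
        return sims * jnp.exp(logit_scale).astype(self.dtype)

module = ScaledLogits(dtype=jnp.bfloat16)
variables = module.init(jax.random.PRNGKey(0), jnp.ones((2, 2), dtype=jnp.bfloat16))
print(variables["params"]["logit_scale"].dtype)  # float32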
...@@ -102,7 +102,7 @@ def get_angles(pos, i, d_model): ...@@ -102,7 +102,7 @@ def get_angles(pos, i, d_model):
return pos * angle_rates return pos * angle_rates
def positional_encoding(position, d_model, dtype): def positional_encoding(position, d_model):
# create the sinusoidal pattern for the positional encoding # create the sinusoidal pattern for the positional encoding
angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model)
...@@ -114,8 +114,7 @@ def positional_encoding(position, d_model, dtype): ...@@ -114,8 +114,7 @@ def positional_encoding(position, d_model, dtype):
pos_encoding = angle_rads[np.newaxis, ...] pos_encoding = angle_rads[np.newaxis, ...]
# cast to dtype return jnp.array(pos_encoding)
return jnp.array(pos_encoding, dtype=dtype)
class FlaxEmbeddings(nn.Module): class FlaxEmbeddings(nn.Module):
...@@ -129,17 +128,15 @@ class FlaxEmbeddings(nn.Module): ...@@ -129,17 +128,15 @@ class FlaxEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.dim, self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
if not self.config.sinusoidal_pos_embds: if not self.config.sinusoidal_pos_embds:
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.dim, self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
else: else:
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim, self.dtype) self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim)
self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.dropout) self.dropout = nn.Dropout(rate=self.config.dropout)
...@@ -153,6 +150,8 @@ class FlaxEmbeddings(nn.Module): ...@@ -153,6 +150,8 @@ class FlaxEmbeddings(nn.Module):
position_embeds = self.position_embeddings(position_ids.astype("i4")) position_embeds = self.position_embeddings(position_ids.astype("i4"))
else: else:
position_embeds = self.pos_encoding[:, :seq_length, :] position_embeds = self.pos_encoding[:, :seq_length, :]
# explicitly cast the positions here, since self.pos_encoding is not registered as a parameter
position_embeds = position_embeds.astype(inputs_embeds.dtype)
# Sum all embeddings # Sum all embeddings
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
...@@ -289,10 +288,10 @@ class FlaxTransformerBlock(nn.Module): ...@@ -289,10 +288,10 @@ class FlaxTransformerBlock(nn.Module):
), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}" ), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}"
self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype) self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype)
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12) self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
self.ffn = FlaxFFN(self.config, dtype=self.dtype) self.ffn = FlaxFFN(self.config, dtype=self.dtype)
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12) self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
def __call__( def __call__(
self, self,
...@@ -412,8 +411,11 @@ class FlaxDistilBertLMDecoder(nn.Module): ...@@ -412,8 +411,11 @@ class FlaxDistilBertLMDecoder(nn.Module):
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, inputs, kernel): def __call__(self, inputs, kernel):
inputs = jnp.asarray(inputs, self.dtype)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ()))) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
y = y + self.bias bias = jnp.asarray(self.bias, self.dtype)
y = y + bias
return y return y
......
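DistilBERT's sinusoidal table is a constant computed with NumPy, not a registered parameter, so it is now built without a dtype and cast to the dtype of the input embeddings at call time (``position_embeds.astype(inputs_embeds.dtype)``). A condensed sketch of that encoding and the late cast (assuming the same ``get_angles`` formula; standalone code, not the library implementation):

import numpy as np
import jax.numpy as jnp

def positional_encoding(position, d_model):
    # classic sinusoidal pattern, kept in default precision
    angles = np.arange(position)[:, None] / np.power(
        10000, (2 * (np.arange(d_model)[None, :] // 2)) / np.float32(d_model)
    )
    angles[:, 0::2] = np.sin(angles[:, 0::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])
    return jnp.array(angles[np.newaxis, ...])

pos_encoding = positional_encoding(512, 64)
inputs_embeds = jnp.ones((1, 16, 64), dtype=jnp.bfloat16)
# cast only at use time, since the table is not a registered parameter
position_embeds = pos_encoding[:, :16, :].astype(inputs_embeds.dtype)
print(position_embeds.dtype)  # bfloat16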
...@@ -148,19 +148,16 @@ class FlaxElectraEmbeddings(nn.Module): ...@@ -148,19 +148,16 @@ class FlaxElectraEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.token_type_embeddings = nn.Embed( self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size, self.config.type_vocab_size,
self.config.embedding_size, self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -196,17 +193,17 @@ class FlaxElectraSelfAttention(nn.Module): ...@@ -196,17 +193,17 @@ class FlaxElectraSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
...@@ -265,7 +262,7 @@ class FlaxElectraSelfOutput(nn.Module): ...@@ -265,7 +262,7 @@ class FlaxElectraSelfOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -313,7 +310,7 @@ class FlaxElectraIntermediate(nn.Module): ...@@ -313,7 +310,7 @@ class FlaxElectraIntermediate(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
...@@ -332,7 +329,7 @@ class FlaxElectraOutput(nn.Module): ...@@ -332,7 +329,7 @@ class FlaxElectraOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -570,7 +567,7 @@ class FlaxElectraModule(nn.Module): ...@@ -570,7 +567,7 @@ class FlaxElectraModule(nn.Module):
def setup(self): def setup(self):
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
if self.config.embedding_size != self.config.hidden_size: if self.config.embedding_size != self.config.hidden_size:
self.embeddings_project = nn.Dense(self.config.hidden_size) self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
def __call__( def __call__(
...@@ -620,17 +617,19 @@ class FlaxElectraTiedDense(nn.Module): ...@@ -620,17 +617,19 @@ class FlaxElectraTiedDense(nn.Module):
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self): def setup(self):
bias = self.param("bias", self.bias_init, (self.embedding_size,)) self.bias = self.param("bias", self.bias_init, (self.embedding_size,))
self.bias = jnp.asarray(bias, dtype=self.dtype)
def __call__(self, x, kernel): def __call__(self, x, kernel):
x = jnp.asarray(x, self.dtype)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general( y = lax.dot_general(
x, x,
kernel, kernel,
(((x.ndim - 1,), (0,)), ((), ())), (((x.ndim - 1,), (0,)), ((), ())),
precision=self.precision, precision=self.precision,
) )
return y + self.bias bias = jnp.asarray(self.bias, self.dtype)
return y + bias
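For readers less familiar with `lax.dot_general`, the dimension numbers used above contract the last axis of `x` with the first axis of `kernel`, i.e. an ordinary matrix multiply (the kernel here standing in for the transposed shared embedding). A minimal sketch with made-up shapes:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 7, 16))     # (batch, seq, embedding_size)
kernel = jnp.ones((16, 50))  # transposed shared embedding: (embedding_size, vocab_size)

y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))
assert y.shape == (2, 7, 50)
assert jnp.allclose(y, x @ kernel)  # identical contraction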
class FlaxElectraForMaskedLMModule(nn.Module): class FlaxElectraForMaskedLMModule(nn.Module):
...@@ -639,7 +638,7 @@ class FlaxElectraForMaskedLMModule(nn.Module): ...@@ -639,7 +638,7 @@ class FlaxElectraForMaskedLMModule(nn.Module):
def setup(self): def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
else: else:
...@@ -788,7 +787,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module): ...@@ -788,7 +787,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module):
else self.config.hidden_dropout_prob else self.config.hidden_dropout_prob
) )
self.dropout = nn.Dropout(classifier_dropout) self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Dense(self.config.num_labels) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__( def __call__(
self, self,
......
...@@ -64,6 +64,18 @@ ENCODER_DECODER_START_DOCSTRING = r""" ...@@ -64,6 +64,18 @@ ENCODER_DECODER_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
ENCODER_DECODER_INPUTS_DOCSTRING = r""" ENCODER_DECODER_INPUTS_DOCSTRING = r"""
......
...@@ -62,6 +62,18 @@ GPT2_START_DOCSTRING = r""" ...@@ -62,6 +62,18 @@ GPT2_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
GPT2_INPUTS_DOCSTRING = r""" GPT2_INPUTS_DOCSTRING = r"""
...@@ -576,13 +588,11 @@ class FlaxGPT2Module(nn.Module): ...@@ -576,13 +588,11 @@ class FlaxGPT2Module(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.embed_dim, self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.wpe = nn.Embed( self.wpe = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.embed_dim, self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.embd_pdrop) self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype) self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
...@@ -666,7 +676,7 @@ class FlaxGPT2LMHeadModule(nn.Module): ...@@ -666,7 +676,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
self.config.vocab_size, self.config.vocab_size,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
) )
def __call__( def __call__(
......
...@@ -60,6 +60,18 @@ GPT_NEO_START_DOCSTRING = r""" ...@@ -60,6 +60,18 @@ GPT_NEO_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
GPT_NEO_INPUTS_DOCSTRING = r""" GPT_NEO_INPUTS_DOCSTRING = r"""
...@@ -119,7 +131,7 @@ class FlaxGPTNeoSelfAttention(nn.Module): ...@@ -119,7 +131,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
nn.Dense, nn.Dense,
self.embed_dim, self.embed_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False) self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False)
...@@ -270,7 +282,7 @@ class FlaxGPTNeoMLP(nn.Module): ...@@ -270,7 +282,7 @@ class FlaxGPTNeoMLP(nn.Module):
def setup(self): def setup(self):
embed_dim = self.config.hidden_size embed_dim = self.config.hidden_size
kernel_init = jax.nn.initializers.normal(self.config.initializer_range, self.dtype) kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
self.act = ACT2FN[self.config.activation_function] self.act = ACT2FN[self.config.activation_function]
...@@ -505,13 +517,11 @@ class FlaxGPTNeoModule(nn.Module): ...@@ -505,13 +517,11 @@ class FlaxGPTNeoModule(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.embed_dim, self.embed_dim,
embedding_init=embedding_init, embedding_init=embedding_init,
dtype=self.dtype,
) )
self.wpe = nn.Embed( self.wpe = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.embed_dim, self.embed_dim,
embedding_init=embedding_init, embedding_init=embedding_init,
dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.embed_dropout) self.dropout = nn.Dropout(rate=self.config.embed_dropout)
self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype) self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype)
...@@ -589,7 +599,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module): ...@@ -589,7 +599,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
self.config.vocab_size, self.config.vocab_size,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype), kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
) )
def __call__( def __call__(
......
...@@ -71,6 +71,18 @@ MARIAN_START_DOCSTRING = r""" ...@@ -71,6 +71,18 @@ MARIAN_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
MARIAN_INPUTS_DOCSTRING = r""" MARIAN_INPUTS_DOCSTRING = r"""
...@@ -206,14 +218,14 @@ MARIAN_DECODE_INPUTS_DOCSTRING = r""" ...@@ -206,14 +218,14 @@ MARIAN_DECODE_INPUTS_DOCSTRING = r"""
""" """
def create_sinusoidal_positions(n_pos, dim, dtype): def create_sinusoidal_positions(n_pos, dim):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
sentinel = dim // 2 + dim % 2 sentinel = dim // 2 + dim % 2
out = np.zeros_like(position_enc) out = np.zeros_like(position_enc)
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
out[:, sentinel:] = np.cos(position_enc[:, 1::2]) out[:, sentinel:] = np.cos(position_enc[:, 1::2])
return jnp.array(out, dtype=dtype) return jnp.array(out)
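Because the table built above is a plain array rather than a registered Flax parameter, the encoder and decoder below cast the gathered positions to the embeddings' dtype at call time. A small sketch of that lookup, with illustrative shapes:

import jax.numpy as jnp

table = create_sinusoidal_positions(n_pos=512, dim=64)  # (512, 64); sines in the first half of the columns, cosines in the second
position_ids = jnp.arange(8)[None, :]                   # (batch=1, seq=8)
positions = jnp.take(table, position_ids, axis=0)       # (1, 8, 64)
positions = positions.astype(jnp.bfloat16)              # match inputs_embeds.dtype when computing in bfloat16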
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
...@@ -252,7 +264,7 @@ class FlaxMarianAttention(nn.Module): ...@@ -252,7 +264,7 @@ class FlaxMarianAttention(nn.Module):
self.embed_dim, self.embed_dim,
use_bias=self.bias, use_bias=self.bias,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
...@@ -409,6 +421,7 @@ class FlaxMarianEncoderLayer(nn.Module): ...@@ -409,6 +421,7 @@ class FlaxMarianEncoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads, num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
...@@ -417,10 +430,10 @@ class FlaxMarianEncoderLayer(nn.Module): ...@@ -417,10 +430,10 @@ class FlaxMarianEncoderLayer(nn.Module):
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -522,6 +535,7 @@ class FlaxMarianDecoderLayer(nn.Module): ...@@ -522,6 +535,7 @@ class FlaxMarianDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
causal=True, causal=True,
dtype=self.dtype,
) )
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function] self.activation_fn = ACT2FN[self.config.activation_function]
...@@ -533,15 +547,16 @@ class FlaxMarianDecoderLayer(nn.Module): ...@@ -533,15 +547,16 @@ class FlaxMarianDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -683,13 +698,10 @@ class FlaxMarianEncoder(nn.Module): ...@@ -683,13 +698,10 @@ class FlaxMarianEncoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.embed_positions = create_sinusoidal_positions( self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
)
self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype) self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
def __call__( def __call__(
...@@ -708,6 +720,8 @@ class FlaxMarianEncoder(nn.Module): ...@@ -708,6 +720,8 @@ class FlaxMarianEncoder(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
positions = jnp.take(self.embed_positions, position_ids, axis=0) positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
...@@ -747,13 +761,10 @@ class FlaxMarianDecoder(nn.Module): ...@@ -747,13 +761,10 @@ class FlaxMarianDecoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.embed_positions = create_sinusoidal_positions( self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
)
self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype) self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
def __call__( def __call__(
...@@ -776,6 +787,8 @@ class FlaxMarianDecoder(nn.Module): ...@@ -776,6 +787,8 @@ class FlaxMarianDecoder(nn.Module):
# embed positions # embed positions
positions = jnp.take(self.embed_positions, position_ids, axis=0) positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
...@@ -812,8 +825,7 @@ class FlaxMarianModule(nn.Module): ...@@ -812,8 +825,7 @@ class FlaxMarianModule(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...@@ -1214,7 +1226,7 @@ class FlaxMarianMTModule(nn.Module): ...@@ -1214,7 +1226,7 @@ class FlaxMarianMTModule(nn.Module):
self.model.shared.num_embeddings, self.model.shared.num_embeddings,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
...@@ -1258,7 +1270,7 @@ class FlaxMarianMTModule(nn.Module): ...@@ -1258,7 +1270,7 @@ class FlaxMarianMTModule(nn.Module):
else: else:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
...@@ -1373,7 +1385,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): ...@@ -1373,7 +1385,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else: else:
lm_logits = module.lm_head(hidden_states) lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs return lm_logits, outputs
......
...@@ -79,6 +79,18 @@ MBART_START_DOCSTRING = r""" ...@@ -79,6 +79,18 @@ MBART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
MBART_INPUTS_DOCSTRING = r""" MBART_INPUTS_DOCSTRING = r"""
...@@ -259,7 +271,7 @@ class FlaxMBartAttention(nn.Module): ...@@ -259,7 +271,7 @@ class FlaxMBartAttention(nn.Module):
self.embed_dim, self.embed_dim,
use_bias=self.bias, use_bias=self.bias,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
...@@ -415,6 +427,7 @@ class FlaxMBartEncoderLayer(nn.Module): ...@@ -415,6 +427,7 @@ class FlaxMBartEncoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads, num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
...@@ -423,10 +436,10 @@ class FlaxMBartEncoderLayer(nn.Module): ...@@ -423,10 +436,10 @@ class FlaxMBartEncoderLayer(nn.Module):
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -526,6 +539,7 @@ class FlaxMBartDecoderLayer(nn.Module): ...@@ -526,6 +539,7 @@ class FlaxMBartDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
causal=True, causal=True,
dtype=self.dtype,
) )
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function] self.activation_fn = ACT2FN[self.config.activation_function]
...@@ -537,15 +551,16 @@ class FlaxMBartDecoderLayer(nn.Module): ...@@ -537,15 +551,16 @@ class FlaxMBartDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -683,13 +698,13 @@ class FlaxMBartClassificationHead(nn.Module): ...@@ -683,13 +698,13 @@ class FlaxMBartClassificationHead(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.dropout = nn.Dropout(rate=self.pooler_dropout) self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense( self.out_proj = nn.Dense(
self.num_classes, self.num_classes,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
...@@ -718,8 +733,7 @@ class FlaxMBartEncoder(nn.Module): ...@@ -718,8 +733,7 @@ class FlaxMBartEncoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
...@@ -728,8 +742,7 @@ class FlaxMBartEncoder(nn.Module): ...@@ -728,8 +742,7 @@ class FlaxMBartEncoder(nn.Module):
self.embed_positions = nn.Embed( self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset, self.config.max_position_embeddings + self.offset,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype) self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
...@@ -795,8 +808,7 @@ class FlaxMBartDecoder(nn.Module): ...@@ -795,8 +808,7 @@ class FlaxMBartDecoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
...@@ -805,8 +817,7 @@ class FlaxMBartDecoder(nn.Module): ...@@ -805,8 +817,7 @@ class FlaxMBartDecoder(nn.Module):
self.embed_positions = nn.Embed( self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset, self.config.max_position_embeddings + self.offset,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype) self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
...@@ -874,8 +885,7 @@ class FlaxMBartModule(nn.Module): ...@@ -874,8 +885,7 @@ class FlaxMBartModule(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...@@ -1280,7 +1290,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module): ...@@ -1280,7 +1290,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings, self.model.shared.num_embeddings,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
...@@ -1324,7 +1334,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module): ...@@ -1324,7 +1334,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
else: else:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
...@@ -1440,7 +1450,7 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel): ...@@ -1440,7 +1450,7 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
else: else:
lm_logits = module.lm_head(hidden_states) lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs return lm_logits, outputs
outputs = self.module.apply( outputs = self.module.apply(
...@@ -1674,7 +1684,7 @@ class FlaxMBartForQuestionAnsweringModule(nn.Module): ...@@ -1674,7 +1684,7 @@ class FlaxMBartForQuestionAnsweringModule(nn.Module):
def setup(self): def setup(self):
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense( self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
def _get_encoder_module(self): def _get_encoder_module(self):
......
...@@ -78,6 +78,18 @@ PEGASUS_START_DOCSTRING = r""" ...@@ -78,6 +78,18 @@ PEGASUS_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
PEGASUS_INPUTS_DOCSTRING = r""" PEGASUS_INPUTS_DOCSTRING = r"""
...@@ -226,7 +238,7 @@ def create_sinusoidal_positions(n_pos, dim, dtype): ...@@ -226,7 +238,7 @@ def create_sinusoidal_positions(n_pos, dim, dtype):
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
out[:, sentinel:] = np.cos(position_enc[:, 1::2]) out[:, sentinel:] = np.cos(position_enc[:, 1::2])
return jnp.array(out, dtype=dtype) return jnp.array(out)
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
...@@ -252,7 +264,7 @@ class FlaxPegasusAttention(nn.Module): ...@@ -252,7 +264,7 @@ class FlaxPegasusAttention(nn.Module):
self.embed_dim, self.embed_dim,
use_bias=self.bias, use_bias=self.bias,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
...@@ -409,6 +421,7 @@ class FlaxPegasusEncoderLayer(nn.Module): ...@@ -409,6 +421,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads, num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
...@@ -417,10 +430,10 @@ class FlaxPegasusEncoderLayer(nn.Module): ...@@ -417,10 +430,10 @@ class FlaxPegasusEncoderLayer(nn.Module):
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -521,6 +534,7 @@ class FlaxPegasusDecoderLayer(nn.Module): ...@@ -521,6 +534,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
causal=True, causal=True,
dtype=self.dtype,
) )
self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function] self.activation_fn = ACT2FN[self.config.activation_function]
...@@ -532,15 +546,16 @@ class FlaxPegasusDecoderLayer(nn.Module): ...@@ -532,15 +546,16 @@ class FlaxPegasusDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads, num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout, dropout=self.config.attention_dropout,
dtype=self.dtype,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.encoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.fc2 = nn.Dense( self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
) )
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
...@@ -683,8 +698,7 @@ class FlaxPegasusEncoder(nn.Module): ...@@ -683,8 +698,7 @@ class FlaxPegasusEncoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.embed_positions = create_sinusoidal_positions( self.embed_positions = create_sinusoidal_positions(
...@@ -710,6 +724,8 @@ class FlaxPegasusEncoder(nn.Module): ...@@ -710,6 +724,8 @@ class FlaxPegasusEncoder(nn.Module):
# embed positions # embed positions
embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
embed_pos = embed_pos.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
...@@ -751,8 +767,7 @@ class FlaxPegasusDecoder(nn.Module): ...@@ -751,8 +767,7 @@ class FlaxPegasusDecoder(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
embed_dim, embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.embed_positions = create_sinusoidal_positions( self.embed_positions = create_sinusoidal_positions(
...@@ -782,6 +797,8 @@ class FlaxPegasusDecoder(nn.Module): ...@@ -782,6 +797,8 @@ class FlaxPegasusDecoder(nn.Module):
# embed positions # embed positions
positions = jnp.take(self.embed_positions, position_ids, axis=0) positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
...@@ -819,8 +836,7 @@ class FlaxPegasusModule(nn.Module): ...@@ -819,8 +836,7 @@ class FlaxPegasusModule(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
...@@ -1224,7 +1240,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module): ...@@ -1224,7 +1240,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings, self.model.shared.num_embeddings,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
...@@ -1268,7 +1284,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module): ...@@ -1268,7 +1284,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
else: else:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
...@@ -1384,7 +1400,7 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): ...@@ -1384,7 +1400,7 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
else: else:
lm_logits = module.lm_head(hidden_states) lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs return lm_logits, outputs
outputs = self.module.apply( outputs = self.module.apply(
......
...@@ -139,19 +139,16 @@ class FlaxRobertaEmbeddings(nn.Module): ...@@ -139,19 +139,16 @@ class FlaxRobertaEmbeddings(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.position_embeddings = nn.Embed( self.position_embeddings = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.token_type_embeddings = nn.Embed( self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size, self.config.type_vocab_size,
self.config.hidden_size, self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -186,17 +183,17 @@ class FlaxRobertaSelfAttention(nn.Module): ...@@ -186,17 +183,17 @@ class FlaxRobertaSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
...@@ -255,7 +252,7 @@ class FlaxRobertaSelfOutput(nn.Module): ...@@ -255,7 +252,7 @@ class FlaxRobertaSelfOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
...@@ -303,7 +300,7 @@ class FlaxRobertaIntermediate(nn.Module): ...@@ -303,7 +300,7 @@ class FlaxRobertaIntermediate(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
...@@ -322,7 +319,7 @@ class FlaxRobertaOutput(nn.Module): ...@@ -322,7 +319,7 @@ class FlaxRobertaOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -444,7 +441,7 @@ class FlaxRobertaPooler(nn.Module): ...@@ -444,7 +441,7 @@ class FlaxRobertaPooler(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -463,14 +460,14 @@ class FlaxRobertaLMHead(nn.Module): ...@@ -463,14 +460,14 @@ class FlaxRobertaLMHead(nn.Module):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.decoder = nn.Dense( self.decoder = nn.Dense(
self.config.vocab_size, self.config.vocab_size,
dtype=self.dtype, dtype=self.dtype,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
...@@ -484,7 +481,8 @@ class FlaxRobertaLMHead(nn.Module): ...@@ -484,7 +481,8 @@ class FlaxRobertaLMHead(nn.Module):
else: else:
hidden_states = self.decoder(hidden_states) hidden_states = self.decoder(hidden_states)
hidden_states += self.bias bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states return hidden_states
...@@ -496,7 +494,7 @@ class FlaxRobertaClassificationHead(nn.Module): ...@@ -496,7 +494,7 @@ class FlaxRobertaClassificationHead(nn.Module):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
classifier_dropout = ( classifier_dropout = (
self.config.classifier_dropout self.config.classifier_dropout
...@@ -507,7 +505,7 @@ class FlaxRobertaClassificationHead(nn.Module): ...@@ -507,7 +505,7 @@ class FlaxRobertaClassificationHead(nn.Module):
self.out_proj = nn.Dense( self.out_proj = nn.Dense(
self.config.num_labels, self.config.num_labels,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, hidden_states, deterministic=True): def __call__(self, hidden_states, deterministic=True):
......
...@@ -98,13 +98,13 @@ class FlaxT5DenseReluDense(nn.Module): ...@@ -98,13 +98,13 @@ class FlaxT5DenseReluDense(nn.Module):
self.wi = nn.Dense( self.wi = nn.Dense(
self.config.d_ff, self.config.d_ff,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.wo = nn.Dense( self.wo = nn.Dense(
self.config.d_model, self.config.d_model,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(self.config.dropout_rate) self.dropout = nn.Dropout(self.config.dropout_rate)
...@@ -128,19 +128,19 @@ class FlaxT5DenseGatedGeluDense(nn.Module): ...@@ -128,19 +128,19 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
self.wi_0 = nn.Dense( self.wi_0 = nn.Dense(
self.config.d_ff, self.config.d_ff,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.wi_1 = nn.Dense( self.wi_1 = nn.Dense(
self.config.d_ff, self.config.d_ff,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.wo = nn.Dense( self.wo = nn.Dense(
self.config.d_model, self.config.d_model,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(self.config.dropout_rate) self.dropout = nn.Dropout(self.config.dropout_rate)
...@@ -200,25 +200,25 @@ class FlaxT5Attention(nn.Module): ...@@ -200,25 +200,25 @@ class FlaxT5Attention(nn.Module):
self.q = nn.Dense( self.q = nn.Dense(
self.inner_dim, self.inner_dim,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(q_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(q_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.k = nn.Dense( self.k = nn.Dense(
self.inner_dim, self.inner_dim,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.v = nn.Dense( self.v = nn.Dense(
self.inner_dim, self.inner_dim,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
self.o = nn.Dense( self.o = nn.Dense(
self.d_model, self.d_model,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(o_init_std, self.dtype), kernel_init=jax.nn.initializers.normal(o_init_std),
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -226,8 +226,7 @@ class FlaxT5Attention(nn.Module): ...@@ -226,8 +226,7 @@ class FlaxT5Attention(nn.Module):
self.relative_attention_bias = nn.Embed( self.relative_attention_bias = nn.Embed(
self.relative_attention_num_buckets, self.relative_attention_num_buckets,
self.n_heads, self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std, self.dtype), embedding_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
) )
@staticmethod @staticmethod
...@@ -500,10 +499,13 @@ class FlaxT5LayerSelfAttention(nn.Module): ...@@ -500,10 +499,13 @@ class FlaxT5LayerSelfAttention(nn.Module):
class FlaxT5LayerCrossAttention(nn.Module): class FlaxT5LayerCrossAttention(nn.Module):
config: T5Config config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.EncDecAttention = FlaxT5Attention(self.config, has_relative_attention_bias=False, causal=False) self.EncDecAttention = FlaxT5Attention(
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon) self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate) self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__( def __call__(
...@@ -537,15 +539,18 @@ class FlaxT5Block(nn.Module): ...@@ -537,15 +539,18 @@ class FlaxT5Block(nn.Module):
self.causal = self.config.causal self.causal = self.config.causal
self.layer = ( self.layer = (
FlaxT5LayerSelfAttention( FlaxT5LayerSelfAttention(
self.config, has_relative_attention_bias=self.has_relative_attention_bias, name=str(0) self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
name=str(0),
dtype=self.dtype,
), ),
) )
feed_forward_index = 1 feed_forward_index = 1
if self.causal: if self.causal:
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1)),) self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
feed_forward_index += 1 feed_forward_index += 1
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index)),) self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
def __call__( def __call__(
self, self,
...@@ -714,11 +719,10 @@ class FlaxT5Stack(nn.Module): ...@@ -714,11 +719,10 @@ class FlaxT5Stack(nn.Module):
self.embed_tokens = nn.Embed( self.embed_tokens = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
) )
self.block = FlaxT5BlockCollection(self.config) self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
self.final_layer_norm = FlaxT5LayerNorm( self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
) )
...@@ -1225,6 +1229,18 @@ T5_START_DOCSTRING = r""" ...@@ -1225,6 +1229,18 @@ T5_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
...@@ -1246,8 +1262,7 @@ class FlaxT5Module(nn.Module): ...@@ -1246,8 +1262,7 @@ class FlaxT5Module(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
) )
encoder_config = copy.deepcopy(self.config) encoder_config = copy.deepcopy(self.config)
...@@ -1358,25 +1373,25 @@ class FlaxT5ForConditionalGenerationModule(nn.Module): ...@@ -1358,25 +1373,25 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
self.shared = nn.Embed( self.shared = nn.Embed(
self.config.vocab_size, self.config.vocab_size,
self.config.d_model, self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype), embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
) )
encoder_config = copy.deepcopy(self.config) encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False encoder_config.causal = False
encoder_config.use_cache = False encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False encoder_config.is_encoder_decoder = False
self.encoder = FlaxT5Stack(encoder_config, self.shared) self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
decoder_config = copy.deepcopy(self.config) decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True decoder_config.causal = True
decoder_config.is_encoder_decoder = False decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, self.shared) self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.lm_head = nn.Dense( self.lm_head = nn.Dense(
self.config.vocab_size, self.config.vocab_size,
use_bias=False, use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype, dtype=self.dtype,
) )
......
...@@ -68,6 +68,18 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r""" ...@@ -68,6 +68,18 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
...@@ -185,7 +197,7 @@ class FlaxVisionEncoderDecoderModule(nn.Module): ...@@ -185,7 +197,7 @@ class FlaxVisionEncoderDecoderModule(nn.Module):
): ):
self.enc_to_dec_proj = nn.Dense( self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size, self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
else: else:
......
...@@ -54,6 +54,18 @@ VIT_START_DOCSTRING = r""" ...@@ -54,6 +54,18 @@ VIT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights. model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
""" """
VIT_INPUTS_DOCSTRING = r""" VIT_INPUTS_DOCSTRING = r"""
...@@ -89,7 +101,7 @@ class FlaxPatchEmbeddings(nn.Module): ...@@ -89,7 +101,7 @@ class FlaxPatchEmbeddings(nn.Module):
strides=(patch_size, patch_size), strides=(patch_size, patch_size),
padding="VALID", padding="VALID",
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__(self, pixel_values): def __call__(self, pixel_values):
...@@ -138,19 +150,19 @@ class FlaxViTSelfAttention(nn.Module): ...@@ -138,19 +150,19 @@ class FlaxViTSelfAttention(nn.Module):
self.query = nn.Dense( self.query = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias, use_bias=self.config.qkv_bias,
) )
self.key = nn.Dense( self.key = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias, use_bias=self.config.qkv_bias,
) )
self.value = nn.Dense( self.value = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias, use_bias=self.config.qkv_bias,
) )
...@@ -196,7 +208,7 @@ class FlaxViTSelfOutput(nn.Module): ...@@ -196,7 +208,7 @@ class FlaxViTSelfOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -235,7 +247,7 @@ class FlaxViTIntermediate(nn.Module): ...@@ -235,7 +247,7 @@ class FlaxViTIntermediate(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.intermediate_size, self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.activation = ACT2FN[self.config.hidden_act] self.activation = ACT2FN[self.config.hidden_act]
...@@ -253,7 +265,7 @@ class FlaxViTOutput(nn.Module): ...@@ -253,7 +265,7 @@ class FlaxViTOutput(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
...@@ -376,7 +388,7 @@ class FlaxViTPooler(nn.Module): ...@@ -376,7 +388,7 @@ class FlaxViTPooler(nn.Module):
def setup(self): def setup(self):
self.dense = nn.Dense( self.dense = nn.Dense(
self.config.hidden_size, self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -533,7 +545,7 @@ class FlaxViTForImageClassificationModule(nn.Module): ...@@ -533,7 +545,7 @@ class FlaxViTForImageClassificationModule(nn.Module):
self.classifier = nn.Dense( self.classifier = nn.Dense(
self.config.num_labels, self.config.num_labels,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
def __call__( def __call__(
......