Unverified commit e92190c0 authored by Suraj Patil, committed by GitHub

Fix Flax params dtype (#13098)



* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove unnecessary import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1c76a516
......@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
......
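This change is the core of the fix and recurs throughout the diff: ``dtype`` is no longer passed to the weight initializers, so parameters are always created in float32 and the module's ``dtype`` only governs computation. A minimal sketch of the difference, with an illustrative seed and shape that are not taken from the diff:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# before: passing dtype to the initializer creates the parameter itself in bfloat16
w_before = jax.nn.initializers.normal(0.02, dtype=jnp.bfloat16)(key, (4, 4))
# after: the parameter stays in the default float32; half precision is applied only at compute time
w_after = jax.nn.initializers.normal(0.02)(key, (4, 4))

print(w_before.dtype, w_after.dtype)  # bfloat16 float32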
......@@ -16,7 +16,7 @@
import os
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union
from typing import Any, Dict, Set, Tuple, Union
import flax.linen as nn
import jax
......@@ -154,6 +154,122 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
)
self._params = params
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast the floating-point values of a given parameter ``PyTree`` to the given ``dtype``.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_map(conditional_cast, params)
flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)
return unflatten_dict(flat_params)
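A minimal sketch of what ``conditional_cast`` does, using a toy tree that is not from the diff: only floating-point leaves are cast, so integer buffers keep their dtype.

import jax
import jax.numpy as jnp

tree = {"kernel": jnp.ones((2, 2), dtype=jnp.float32), "position_ids": jnp.arange(4)}
casted = jax.tree_map(
    lambda x: x.astype(jnp.bfloat16) if jnp.issubdtype(x.dtype, jnp.floating) else x, tree
)
# casted["kernel"].dtype -> bfloat16, casted["position_ids"].dtype -> int32 (unchanged)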
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # load model
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> model.params = model.to_bf16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_bf16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in
place.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> model.params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> model.params = model.to_fp32(model.params)
"""
return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point ``params`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not
cast the ``params`` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` of model parameters.
mask (:obj:`Union[Dict, FrozenDict]`):
A ``PyTree`` with the same structure as the ``params`` tree. The leaves should be booleans: :obj:`True` for
params you want to cast and :obj:`False` for those you want to skip.
Examples::
>>> from transformers import FlaxBertModel
>>> # Download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # By default, the model params will be in fp32, to cast these to float16
>>> model.params = model.to_fp16(model.params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_fp16(model.params, mask)
"""
return self._cast_floating_to(params, jnp.float16, mask)
@classmethod
def from_pretrained(
cls,
......@@ -184,6 +300,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
case, ``from_pt`` should be set to :obj:`True`.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and
:meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
......
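A short sketch of how the new ``dtype`` argument is meant to combine with the casting helpers above; the model name is reused from the docstrings and the exact setup is illustrative:

import jax.numpy as jnp
from transformers import FlaxBertModel

# computation runs in bfloat16, parameters are still stored in float32
model = FlaxBertModel.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)

# optionally also store the parameters in bfloat16
model.params = model.to_bf16(model.params)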
......@@ -105,6 +105,18 @@ ALBERT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
ALBERT_INPUTS_DOCSTRING = r"""
......@@ -152,19 +164,16 @@ class FlaxAlbertEmbeddings(nn.Module):
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -199,21 +208,21 @@ class FlaxAlbertSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -278,13 +287,13 @@ class FlaxAlbertLayer(nn.Module):
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module):
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
......@@ -596,7 +605,7 @@ class FlaxAlbertModule(nn.Module):
if self.add_pooling_layer:
self.pooler = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
name="pooler",
)
......
......@@ -79,6 +79,18 @@ BART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
BART_INPUTS_DOCSTRING = r"""
......@@ -248,7 +260,7 @@ class FlaxBartAttention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
......@@ -404,6 +416,7 @@ class FlaxBartEncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
......@@ -412,10 +425,10 @@ class FlaxBartEncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -514,6 +527,7 @@ class FlaxBartDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
......@@ -525,15 +539,16 @@ class FlaxBartDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -668,13 +683,13 @@ class FlaxBartClassificationHead(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
......@@ -703,8 +718,7 @@ class FlaxBartEncoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
......@@ -713,8 +727,7 @@ class FlaxBartEncoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
......@@ -776,8 +789,7 @@ class FlaxBartDecoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
......@@ -786,8 +798,7 @@ class FlaxBartDecoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
......@@ -850,8 +861,7 @@ class FlaxBartModule(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
......@@ -1256,7 +1266,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
......@@ -1300,7 +1310,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
......@@ -1416,7 +1426,7 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
......@@ -1647,7 +1657,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module):
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):
......
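The ``final_logits_bias`` change follows the same idea: a raw ``self.param`` is not dtype-managed the way ``nn.Dense`` weights are, so it is cast to the computation ``dtype`` at call time. A minimal sketch of the pattern, using an illustrative module rather than the actual BART head:

import jax.numpy as jnp
import flax.linen as nn

class TinyHead(nn.Module):
    vocab_size: int = 8
    dtype: jnp.dtype = jnp.float32  # computation dtype

    @nn.compact
    def __call__(self, hidden_states):
        logits = nn.Dense(self.vocab_size, dtype=self.dtype)(hidden_states)
        bias = self.param("final_logits_bias", nn.initializers.zeros, (1, self.vocab_size))
        # cast the raw parameter explicitly so the addition happens in the computation dtype
        return logits + jnp.asarray(bias, self.dtype)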
......@@ -86,6 +86,18 @@ BEIT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
BEIT_INPUTS_DOCSTRING = r"""
......
......@@ -106,6 +106,31 @@ BERT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
BERT_INPUTS_DOCSTRING = r"""
......@@ -153,19 +178,16 @@ class FlaxBertEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -199,17 +221,17 @@ class FlaxBertSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
......@@ -267,7 +289,7 @@ class FlaxBertSelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -313,7 +335,7 @@ class FlaxBertIntermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -331,7 +353,7 @@ class FlaxBertOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -449,7 +471,7 @@ class FlaxBertPooler(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
......@@ -492,7 +514,8 @@ class FlaxBertLMPredictionHead(nn.Module):
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states
......
......@@ -136,6 +136,18 @@ BIG_BIRD_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
BIG_BIRD_INPUTS_DOCSTRING = r"""
......@@ -184,19 +196,16 @@ class FlaxBigBirdEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -234,17 +243,17 @@ class FlaxBigBirdSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
......@@ -305,19 +314,19 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
self.config.hidden_size,
dtype=self.dtype,
use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
use_bias=self.config.use_bias,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
@staticmethod
......@@ -1074,7 +1083,7 @@ class FlaxBigBirdSelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -1131,7 +1140,7 @@ class FlaxBigBirdIntermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -1150,7 +1159,7 @@ class FlaxBigBirdOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -1301,7 +1310,8 @@ class FlaxBigBirdLMPredictionHead(nn.Module):
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states
......@@ -1431,7 +1441,7 @@ class FlaxBigBirdModule(nn.Module):
self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
self.pooler = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
......
......@@ -60,6 +60,18 @@ CLIP_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
......@@ -262,18 +274,10 @@ class FlaxCLIPAttention(nn.Module):
self.scale = self.head_dim ** -0.5
self.dropout = self.config.attention_dropout
self.k_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.v_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.q_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.out_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.causal = isinstance(self.config, CLIPTextConfig)
if self.causal:
......@@ -354,11 +358,9 @@ class FlaxCLIPMLP(nn.Module):
self.fc1 = nn.Dense(
self.config.intermediate_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype),
)
self.fc2 = nn.Dense(
self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
kernel_init=jax.nn.initializers.normal(0.01),
)
self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
......@@ -1032,18 +1034,18 @@ class FlaxCLIPModule(nn.Module):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param(
"logit_scale", lambda _, shape: jnp.ones(shape, dtype=self.dtype) * self.config.logit_scale_init_value, []
"logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
)
def __call__(
......
......@@ -102,7 +102,7 @@ def get_angles(pos, i, d_model):
return pos * angle_rates
def positional_encoding(position, d_model, dtype):
def positional_encoding(position, d_model):
# create the sinusoidal pattern for the positional encoding
angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model)
......@@ -114,8 +114,7 @@ def positional_encoding(position, d_model, dtype):
pos_encoding = angle_rads[np.newaxis, ...]
# cast to dtype
return jnp.array(pos_encoding, dtype=dtype)
return jnp.array(pos_encoding)
class FlaxEmbeddings(nn.Module):
......@@ -129,17 +128,15 @@ class FlaxEmbeddings(nn.Module):
self.config.vocab_size,
self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
if not self.config.sinusoidal_pos_embds:
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
else:
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim, self.dtype)
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim)
self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.dropout)
......@@ -153,6 +150,8 @@ class FlaxEmbeddings(nn.Module):
position_embeds = self.position_embeddings(position_ids.astype("i4"))
else:
position_embeds = self.pos_encoding[:, :seq_length, :]
# explicitly cast the positions here, since self.pos_encoding is not registered as a parameter
position_embeds = position_embeds.astype(inputs_embeds.dtype)
# Sum all embeddings
hidden_states = inputs_embeds + position_embeds
......@@ -289,10 +288,10 @@ class FlaxTransformerBlock(nn.Module):
), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}"
self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype)
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12)
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
self.ffn = FlaxFFN(self.config, dtype=self.dtype)
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12)
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
def __call__(
self,
......@@ -412,8 +411,11 @@ class FlaxDistilBertLMDecoder(nn.Module):
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, inputs, kernel):
inputs = jnp.asarray(inputs, self.dtype)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
y = y + self.bias
bias = jnp.asarray(self.bias, self.dtype)
y = y + bias
return y
......
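Because the sinusoidal tables (``pos_encoding`` here, ``embed_positions`` in Marian below) are plain arrays rather than registered parameters, ``to_fp16``/``to_bf16`` never touch them; hence the explicit ``astype`` at call time. A toy sketch of the pattern, with illustrative shapes:

import jax.numpy as jnp
import numpy as np

table = jnp.array(np.random.randn(16, 4))                 # fixed float32 table, not a parameter
inputs_embeds = jnp.ones((1, 5, 4), dtype=jnp.bfloat16)   # embeddings already in the compute dtype

positions = table[None, :5, :].astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions                  # stays in bfloat16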
......@@ -148,19 +148,16 @@ class FlaxElectraEmbeddings(nn.Module):
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -196,17 +193,17 @@ class FlaxElectraSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
......@@ -265,7 +262,7 @@ class FlaxElectraSelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -313,7 +310,7 @@ class FlaxElectraIntermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -332,7 +329,7 @@ class FlaxElectraOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -570,7 +567,7 @@ class FlaxElectraModule(nn.Module):
def setup(self):
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
if self.config.embedding_size != self.config.hidden_size:
self.embeddings_project = nn.Dense(self.config.hidden_size)
self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
def __call__(
......@@ -620,17 +617,19 @@ class FlaxElectraTiedDense(nn.Module):
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
bias = self.param("bias", self.bias_init, (self.embedding_size,))
self.bias = jnp.asarray(bias, dtype=self.dtype)
self.bias = self.param("bias", self.bias_init, (self.embedding_size,))
def __call__(self, x, kernel):
x = jnp.asarray(x, self.dtype)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(
x,
kernel,
(((x.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
return y + self.bias
bias = jnp.asarray(self.bias, self.dtype)
return y + bias
class FlaxElectraForMaskedLMModule(nn.Module):
......@@ -639,7 +638,7 @@ class FlaxElectraForMaskedLMModule(nn.Module):
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
else:
......@@ -788,7 +787,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module):
else self.config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Dense(self.config.num_labels)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
......
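The Electra changes above ("make sure dtype is passed to every module") matter because a ``flax.linen`` layer only computes in half precision when its ``dtype`` is set, while its parameters are still created in float32 by the initializer. A small sketch with illustrative shapes:

import jax
import jax.numpy as jnp
import flax.linen as nn

dense = nn.Dense(3, dtype=jnp.bfloat16)
x = jnp.ones((1, 4), dtype=jnp.float32)

params = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(params, x)

print(jax.tree_map(lambda p: p.dtype, params))  # kernel and bias stay float32
print(y.dtype)                                  # bfloat16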
......@@ -64,6 +64,18 @@ ENCODER_DECODER_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
......
......@@ -62,6 +62,18 @@ GPT2_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
GPT2_INPUTS_DOCSTRING = r"""
......@@ -576,13 +588,11 @@ class FlaxGPT2Module(nn.Module):
self.config.vocab_size,
self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.wpe = nn.Embed(
self.config.max_position_embeddings,
self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
......@@ -666,7 +676,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
......
......@@ -60,6 +60,18 @@ GPT_NEO_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
GPT_NEO_INPUTS_DOCSTRING = r"""
......@@ -119,7 +131,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
nn.Dense,
self.embed_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False)
......@@ -270,7 +282,7 @@ class FlaxGPTNeoMLP(nn.Module):
def setup(self):
embed_dim = self.config.hidden_size
kernel_init = jax.nn.initializers.normal(self.config.initializer_range, self.dtype)
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
self.act = ACT2FN[self.config.activation_function]
......@@ -505,13 +517,11 @@ class FlaxGPTNeoModule(nn.Module):
self.config.vocab_size,
self.embed_dim,
embedding_init=embedding_init,
dtype=self.dtype,
)
self.wpe = nn.Embed(
self.config.max_position_embeddings,
self.embed_dim,
embedding_init=embedding_init,
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.embed_dropout)
self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype)
......@@ -589,7 +599,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
......
......@@ -71,6 +71,18 @@ MARIAN_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
MARIAN_INPUTS_DOCSTRING = r"""
......@@ -206,14 +218,14 @@ MARIAN_DECODE_INPUTS_DOCSTRING = r"""
"""
def create_sinusoidal_positions(n_pos, dim, dtype):
def create_sinusoidal_positions(n_pos, dim):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
sentinel = dim // 2 + dim % 2
out = np.zeros_like(position_enc)
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
out[:, sentinel:] = np.cos(position_enc[:, 1::2])
return jnp.array(out, dtype=dtype)
return jnp.array(out)
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
......@@ -252,7 +264,7 @@ class FlaxMarianAttention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
......@@ -409,6 +421,7 @@ class FlaxMarianEncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
......@@ -417,10 +430,10 @@ class FlaxMarianEncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -522,6 +535,7 @@ class FlaxMarianDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
......@@ -533,15 +547,16 @@ class FlaxMarianDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -683,13 +698,10 @@ class FlaxMarianEncoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.embed_positions = create_sinusoidal_positions(
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
)
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
def __call__(
......@@ -708,6 +720,8 @@ class FlaxMarianEncoder(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
......@@ -747,13 +761,10 @@ class FlaxMarianDecoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.embed_positions = create_sinusoidal_positions(
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
)
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
def __call__(
......@@ -776,6 +787,8 @@ class FlaxMarianDecoder(nn.Module):
# embed positions
positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions
......@@ -812,8 +825,7 @@ class FlaxMarianModule(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
......@@ -1214,7 +1226,7 @@ class FlaxMarianMTModule(nn.Module):
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
......@@ -1258,7 +1270,7 @@ class FlaxMarianMTModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
......@@ -1373,7 +1385,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
......
......@@ -79,6 +79,18 @@ MBART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
MBART_INPUTS_DOCSTRING = r"""
......@@ -259,7 +271,7 @@ class FlaxMBartAttention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
......@@ -415,6 +427,7 @@ class FlaxMBartEncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
......@@ -423,10 +436,10 @@ class FlaxMBartEncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -526,6 +539,7 @@ class FlaxMBartDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
......@@ -537,15 +551,16 @@ class FlaxMBartDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -683,13 +698,13 @@ class FlaxMBartClassificationHead(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
......@@ -718,8 +733,7 @@ class FlaxMBartEncoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
......@@ -728,8 +742,7 @@ class FlaxMBartEncoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
......@@ -795,8 +808,7 @@ class FlaxMBartDecoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
......@@ -805,8 +817,7 @@ class FlaxMBartDecoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
......@@ -874,8 +885,7 @@ class FlaxMBartModule(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
......@@ -1280,7 +1290,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
......@@ -1324,7 +1334,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
......@@ -1440,7 +1450,7 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
......@@ -1674,7 +1684,7 @@ class FlaxMBartForQuestionAnsweringModule(nn.Module):
def setup(self):
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):
......
......@@ -78,6 +78,18 @@ PEGASUS_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
PEGASUS_INPUTS_DOCSTRING = r"""
......@@ -226,7 +238,7 @@ def create_sinusoidal_positions(n_pos, dim, dtype):
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
out[:, sentinel:] = np.cos(position_enc[:, 1::2])
return jnp.array(out, dtype=dtype)
return jnp.array(out)
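Illustration (not part of the diff): the hunk above stops casting the sinusoidal table inside the helper; callers cast it to the computation dtype at use time instead. A reconstruction of the full helper under the usual Transformer sinusoidal definition (the lines before ``out[:, 0:sentinel]`` are not shown in the diff, and dropping the now-unused ``dtype`` argument is an assumption)::

    import numpy as np
    import jax.numpy as jnp

    def create_sinusoidal_positions(n_pos, dim):
        # angle(pos, j) = pos / 10000 ** (2 * (j // 2) / dim)
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        sentinel = dim // 2 + dim % 2
        out = np.zeros_like(position_enc)
        out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
        out[:, sentinel:] = np.cos(position_enc[:, 1::2])
        # no dtype cast here; the table is cast to the computation dtype where it is used
        return jnp.array(out)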
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
......@@ -252,7 +264,7 @@ class FlaxPegasusAttention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
......@@ -409,6 +421,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
......@@ -417,10 +430,10 @@ class FlaxPegasusEncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -521,6 +534,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
......@@ -532,15 +546,16 @@ class FlaxPegasusDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
......@@ -683,8 +698,7 @@ class FlaxPegasusEncoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.embed_positions = create_sinusoidal_positions(
......@@ -710,6 +724,8 @@ class FlaxPegasusEncoder(nn.Module):
# embed positions
embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
embed_pos = embed_pos.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
......@@ -751,8 +767,7 @@ class FlaxPegasusDecoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.embed_positions = create_sinusoidal_positions(
......@@ -782,6 +797,8 @@ class FlaxPegasusDecoder(nn.Module):
# embed positions
positions = jnp.take(self.embed_positions, position_ids, axis=0)
# explicitly cast the positions here, since self.embed_positions is not registered as a parameter
positions = positions.astype(inputs_embeds.dtype)
hidden_states = inputs_embeds + positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
......@@ -819,8 +836,7 @@ class FlaxPegasusModule(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
......@@ -1224,7 +1240,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
......@@ -1268,7 +1284,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
......@@ -1384,7 +1400,7 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
......
......@@ -139,19 +139,16 @@ class FlaxRobertaEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -186,17 +183,17 @@ class FlaxRobertaSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
......@@ -255,7 +252,7 @@ class FlaxRobertaSelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
......@@ -303,7 +300,7 @@ class FlaxRobertaIntermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -322,7 +319,7 @@ class FlaxRobertaOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -444,7 +441,7 @@ class FlaxRobertaPooler(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
......@@ -463,14 +460,14 @@ class FlaxRobertaLMHead(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.decoder = nn.Dense(
self.config.vocab_size,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
......@@ -484,7 +481,8 @@ class FlaxRobertaLMHead(nn.Module):
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states
......@@ -496,7 +494,7 @@ class FlaxRobertaClassificationHead(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
classifier_dropout = (
self.config.classifier_dropout
......@@ -507,7 +505,7 @@ class FlaxRobertaClassificationHead(nn.Module):
self.out_proj = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, deterministic=True):
......
......@@ -98,13 +98,13 @@ class FlaxT5DenseReluDense(nn.Module):
self.wi = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
......@@ -128,19 +128,19 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
self.wi_0 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wi_1 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
......@@ -200,25 +200,25 @@ class FlaxT5Attention(nn.Module):
self.q = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(q_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(q_init_std),
dtype=self.dtype,
)
self.k = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
self.v = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
self.o = nn.Dense(
self.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(o_init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(o_init_std),
dtype=self.dtype,
)
......@@ -226,8 +226,7 @@ class FlaxT5Attention(nn.Module):
self.relative_attention_bias = nn.Embed(
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(kv_init_std),
)
@staticmethod
......@@ -500,10 +499,13 @@ class FlaxT5LayerSelfAttention(nn.Module):
class FlaxT5LayerCrossAttention(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.EncDecAttention = FlaxT5Attention(self.config, has_relative_attention_bias=False, causal=False)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon)
self.EncDecAttention = FlaxT5Attention(
self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__(
......@@ -537,15 +539,18 @@ class FlaxT5Block(nn.Module):
self.causal = self.config.causal
self.layer = (
FlaxT5LayerSelfAttention(
self.config, has_relative_attention_bias=self.has_relative_attention_bias, name=str(0)
self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
name=str(0),
dtype=self.dtype,
),
)
feed_forward_index = 1
if self.causal:
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1)),)
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
feed_forward_index += 1
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index)),)
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
def __call__(
self,
......@@ -714,11 +719,10 @@ class FlaxT5Stack(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.block = FlaxT5BlockCollection(self.config)
self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
......@@ -1225,6 +1229,18 @@ T5_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
......@@ -1246,8 +1262,7 @@ class FlaxT5Module(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
)
encoder_config = copy.deepcopy(self.config)
......@@ -1358,25 +1373,25 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = FlaxT5Stack(encoder_config, self.shared)
self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, self.shared)
self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
......
......@@ -68,6 +68,18 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
......@@ -185,7 +197,7 @@ class FlaxVisionEncoderDecoderModule(nn.Module):
):
self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
dtype=self.dtype,
)
else:
......
......@@ -54,6 +54,18 @@ VIT_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
VIT_INPUTS_DOCSTRING = r"""
......@@ -89,7 +101,7 @@ class FlaxPatchEmbeddings(nn.Module):
strides=(patch_size, patch_size),
padding="VALID",
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, pixel_values):
......@@ -138,19 +150,19 @@ class FlaxViTSelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias,
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias,
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=self.config.qkv_bias,
)
......@@ -196,7 +208,7 @@ class FlaxViTSelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -235,7 +247,7 @@ class FlaxViTIntermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -253,7 +265,7 @@ class FlaxViTOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
......@@ -376,7 +388,7 @@ class FlaxViTPooler(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
......@@ -533,7 +545,7 @@ class FlaxViTForImageClassificationModule(nn.Module):
self.classifier = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(
......