Unverified commit 75627148, authored by Funtowicz Morgan, committed by GitHub

Flax Masked Language Modeling training example (#8728)



* Remove "Model" suffix from Flax models to look more :hugs:
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Initial working (forward + backward) Flax MLM training example.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Simplify code
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing comments, using a module and moving to the LM task.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore parameter name "module", wrongly renamed to "model".
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore correct output ordering...
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Actually commit the example 😅
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add FlaxBertForMaskedLM after rebasing.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make it possible to initialize the training from scratch
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reuse the Flax linen example's cross-entropy loss
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
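For reference, the cross-entropy pattern borrowed from the Flax linen examples looks roughly like the sketch below; the helper name and the mean reduction are illustrative, not the training script's verbatim code.

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import onehot


def cross_entropy(logits, targets, vocab_size):
    # One-hot encode the target token ids and average the negative
    # log-likelihood over all positions in the batch.
    log_probs = jax.nn.log_softmax(logits)
    labels = onehot(targets, vocab_size)
    loss = -jnp.sum(labels * log_probs, axis=-1)
    return jnp.mean(loss)
```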

* Added a Flax-specific data collator
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove todo for data collator
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added evaluation step
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added the ability to provide a dtype to support bfloat16 on TPU
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
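With the new `dtype` argument on `FlaxPreTrainedModel.from_pretrained` (see the hunks below), switching the computation dtype on TPU is roughly the following sketch; the checkpoint name is only an example.

```python
import jax.numpy as jnp
from transformers import FlaxBertForMaskedLM

# bfloat16 halves activation memory on TPU; the dtype is forwarded to every
# Flax submodule through model_kwargs (checkpoint name is illustrative).
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)
```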

* Enable Flax TensorBoard output
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable jax.pmap support.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure batches are correctly sized to be dispatched with jax.pmap
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
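The sizing requirement comes from `jax.pmap` expecting a leading device axis. A minimal sketch of the reshaping, assuming a dict of NumPy arrays per batch (the helper name is hypothetical, not the example script's exact code):

```python
import jax
import numpy as np


def shard(batch):
    """Reshape every array from (batch_size, ...) to
    (num_devices, batch_size // num_devices, ...) so jax.pmap can send one
    shard to each local device. The global batch size must therefore be a
    multiple of jax.local_device_count(); leftover samples are dropped or
    padded beforehand."""
    num_devices = jax.local_device_count()
    return {
        k: np.asarray(v).reshape((num_devices, -1) + np.asarray(v).shape[1:])
        for k, v in batch.items()
    }
```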

* Enable bfloat16 with the --fp16 command-line argument
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correctly export metrics to TensorBoard
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added dropout and the ability to use it.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Effectively enable & disable dropout during training and evaluation steps.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
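Concretely, the model `__call__` now takes `train` and `dropout_rng` (see the FlaxRobertaModel hunk below): dropout is active only when `train=True` and a PRNG key is supplied, and stays deterministic at evaluation time. A rough usage sketch:

```python
import jax
from transformers import FlaxRobertaModel, RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = FlaxRobertaModel.from_pretrained("roberta-base")
inputs = tokenizer("Hello world", return_tensors="np")

rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))

# Training step: train=True plus a dropout PRNG key activates nn.Dropout.
train_outputs = model(**inputs, dropout_rng=dropout_rng, train=True)

# Evaluation: the default train=False makes every dropout layer deterministic.
eval_outputs = model(**inputs)
```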

* Oops.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable specifying the kernel initializer scale
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added warmup steps to the learning rate scheduler.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
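The warmup ramps the learning rate up linearly before the main schedule takes over; a simplified, self-contained sketch (the function name and the constant rate after warmup are assumptions, not the script's exact schedule):

```python
import jax.numpy as jnp


def make_lr_schedule(base_lr: float = 1e-4, warmup_steps: int = 1000):
    """Linear warmup from 0 to base_lr over warmup_steps, constant afterwards."""

    def lr(step):
        return base_lr * jnp.minimum(1.0, step / warmup_steps)

    return lr
```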

* Fix typo.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Print training loss
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix linter issue (flake8)
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix model matching
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix dummies
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix non-default dtype on Flax models
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same create_position_ids_from_input_ids for FlaxRoberta
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make RoBERTa attention match BERT's
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix copy
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Wording.
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
parent df2af6d8
@@ -936,7 +936,7 @@ else:
 if is_flax_available():
     from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
-    from .models.bert import FlaxBertModel
+    from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
     from .models.roberta import FlaxRobertaModel
 else:
     # Import the same objects as dummies to get them in the namespace.
...
@@ -65,13 +65,12 @@ class FlaxPreTrainedModel(ABC):
     base_model_prefix = ""
     model_class = None

-    def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
+    def __init__(
+        self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
+    ):
         if config is None:
             raise ValueError("config cannot be None")

-        if module is None:
-            raise ValueError("module cannot be None")
-
         if params is None:
             raise ValueError("state cannot be None")
@@ -82,19 +81,23 @@ class FlaxPreTrainedModel(ABC):
         # Those are public as their type is generic to every derived classes.
         self.key = PRNGKey(seed)
         self.params = params
-        self.model = module
+        self.dtype = dtype

     @property
     def config(self) -> PretrainedConfig:
         return self._config

+    @property
+    def module(self) -> nn.Module:
+        return self._module
+
     @staticmethod
     @abstractmethod
     def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
         raise NotImplementedError()

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
         r"""
         Instantiate a pretrained Flax model from a pre-trained model configuration.
         """
@@ -127,6 +130,9 @@ class FlaxPreTrainedModel(ABC):
         else:
             model_kwargs = kwargs

+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
         # Load model
         if pretrained_model_name_or_path is not None:
             if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
...
@@ -59,4 +59,4 @@ if is_tf_available():
     )

 if is_flax_available():
-    from .modeling_flax_bert import FlaxBertModel
+    from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Dict
+from typing import Callable, Dict, Tuple

 import numpy as np

 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+from jax.random import PRNGKey

 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
@@ -101,8 +102,8 @@ class FlaxRobertaLayerNorm(nn.Module):
     scale: bool = True  # If True, multiply by scale (gamma). When the next layer is linear
     # (also e.g. nn.relu), this can be disabled since the scaling will be
     # done by the next layer.
-    bias_init: jnp.ndarray = nn.initializers.zeros
-    scale_init: jnp.ndarray = nn.initializers.ones
+    scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
+    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

     @nn.compact
     def __call__(self, x):
@@ -122,11 +123,13 @@ class FlaxRobertaLayerNorm(nn.Module):
         mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
         var = mean2 - jax.lax.square(mean)
         mul = jax.lax.rsqrt(var + self.epsilon)
         if self.scale:
-            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
+            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
         y = (x - mean) * mul
         if self.bias:
-            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
+            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
         return y
@@ -139,7 +142,9 @@ class FlaxRobertaEmbedding(nn.Module):
     vocab_size: int
     hidden_size: int
-    emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
+    kernel_init_scale: float = 0.2
+    emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
     def __call__(self, inputs):
@@ -155,66 +160,108 @@ class FlaxRobertaEmbeddings(nn.Module):
     hidden_size: int
     type_vocab_size: int
     max_length: int
+    kernel_init_scale: float = 0.2
+    dropout_rate: float = 0.0
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):

         # Embed
-        w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
-            jnp.atleast_2d(input_ids.astype("i4"))
-        )
-        p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
-            jnp.atleast_2d(position_ids.astype("i4"))
-        )
-        t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
-            jnp.atleast_2d(token_type_ids.astype("i4"))
-        )
+        w_emb = FlaxRobertaEmbedding(
+            self.vocab_size,
+            self.hidden_size,
+            kernel_init_scale=self.kernel_init_scale,
+            name="word_embeddings",
+            dtype=self.dtype,
+        )(jnp.atleast_2d(input_ids.astype("i4")))
+        p_emb = FlaxRobertaEmbedding(
+            self.max_length,
+            self.hidden_size,
+            kernel_init_scale=self.kernel_init_scale,
+            name="position_embeddings",
+            dtype=self.dtype,
+        )(jnp.atleast_2d(position_ids.astype("i4")))
+        t_emb = FlaxRobertaEmbedding(
+            self.type_vocab_size,
+            self.hidden_size,
+            kernel_init_scale=self.kernel_init_scale,
+            name="token_type_embeddings",
+            dtype=self.dtype,
+        )(jnp.atleast_2d(token_type_ids.astype("i4")))

         # Sum all embeddings
         summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

         # Layer Norm
-        layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
+        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
+        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)

-        return layer_norm
+        return embeddings


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
 class FlaxRobertaAttention(nn.Module):
     num_heads: int
     head_size: int
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, hidden_state, attention_mask):
+    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
         # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
         # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
         # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
         attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

-        self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
-            hidden_state, attention_mask
-        )
+        self_att = nn.attention.SelfAttention(
+            num_heads=self.num_heads,
+            qkv_features=self.head_size,
+            dropout_rate=self.dropout_rate,
+            deterministic=deterministic,
+            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+            bias_init=jax.nn.initializers.zeros,
+            name="self",
+            dtype=self.dtype,
+        )(hidden_state, attention_mask)

-        layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
+        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
         return layer_norm


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
 class FlaxRobertaIntermediate(nn.Module):
     output_size: int
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
     def __call__(self, hidden_state):
         # TODO: Add ACT2FN reference to change activation function
-        dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
+        dense = nn.Dense(
+            features=self.output_size,
+            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+            name="dense",
+            dtype=self.dtype,
+        )(hidden_state)
         return gelu(dense)


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
 class FlaxRobertaOutput(nn.Module):
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, intermediate_output, attention_output):
-        hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
-        hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
+    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
+        hidden_state = nn.Dense(
+            attention_output.shape[-1],
+            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+            name="dense",
+            dtype=self.dtype,
+        )(intermediate_output)
+        hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
+        hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
         return hidden_state
@@ -222,14 +269,29 @@ class FlaxRobertaLayer(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, hidden_state, attention_mask):
-        attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
-            hidden_state, attention_mask
-        )
-        intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
-        output = FlaxRobertaOutput(name="output")(intermediate, attention)
+    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
+        attention = FlaxRobertaAttention(
+            self.num_heads,
+            self.head_size,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="attention",
+            dtype=self.dtype,
+        )(hidden_state, attention_mask, deterministic=deterministic)
+        intermediate = FlaxRobertaIntermediate(
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            name="intermediate",
+            dtype=self.dtype,
+        )(attention)
+        output = FlaxRobertaOutput(
+            kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
+        )(intermediate, attention, deterministic=deterministic)

         return output
@@ -244,9 +306,12 @@ class FlaxRobertaLayerCollection(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, inputs, attention_mask):
+    def __call__(self, inputs, attention_mask, deterministic: bool = True):
         assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"

         # Initialize input / output
@@ -254,8 +319,16 @@ class FlaxRobertaLayerCollection(nn.Module):
         # Forward over all encoders
         for i in range(self.num_layers):
-            layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
-            input_i = layer(input_i, attention_mask)
+            layer = FlaxRobertaLayer(
+                self.num_heads,
+                self.head_size,
+                self.intermediate_size,
+                kernel_init_scale=self.kernel_init_scale,
+                dropout_rate=self.dropout_rate,
+                name=f"{i}",
+                dtype=self.dtype,
+            )
+            input_i = layer(input_i, attention_mask, deterministic=deterministic)
         return input_i
@@ -265,22 +338,40 @@ class FlaxRobertaEncoder(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, hidden_state, attention_mask):
+    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
         layer = FlaxRobertaLayerCollection(
-            self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
-        )(hidden_state, attention_mask)
+            self.num_layers,
+            self.num_heads,
+            self.head_size,
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="layer",
+            dtype=self.dtype,
+        )(hidden_state, attention_mask, deterministic=deterministic)
         return layer


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
 class FlaxRobertaPooler(nn.Module):
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
     def __call__(self, hidden_state):
         cls_token = hidden_state[:, 0]
-        out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
-        return jax.lax.tanh(out)
+        out = nn.Dense(
+            hidden_state.shape[-1],
+            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+            name="dense",
+            dtype=self.dtype,
+        )(cls_token)
+        return nn.tanh(out)


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
@@ -293,21 +384,38 @@ class FlaxRobertaModule(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
+    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):

         # Embedding
         embeddings = FlaxRobertaEmbeddings(
-            self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
-        )(input_ids, token_type_ids, position_ids, attention_mask)
+            self.vocab_size,
+            self.hidden_size,
+            self.type_vocab_size,
+            self.max_length,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="embeddings",
+            dtype=self.dtype,
+        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)

         # N stacked encoding layers
         encoder = FlaxRobertaEncoder(
-            self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
-        )(embeddings, attention_mask)
-
-        pooled = FlaxRobertaPooler(name="pooler")(encoder)
+            self.num_encoder_layers,
+            self.num_heads,
+            self.head_size,
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="encoder",
+            dtype=self.dtype,
+        )(embeddings, attention_mask, deterministic=deterministic)

+        pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
         return encoder, pooled
@@ -396,8 +504,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
         return jax_state

-    def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
-        model = FlaxRobertaModule(
+    def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
+        module = FlaxRobertaModule(
             vocab_size=config.vocab_size,
             hidden_size=config.hidden_size,
             type_vocab_size=config.type_vocab_size,
@@ -406,31 +514,78 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
             num_heads=config.num_attention_heads,
             head_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            dropout_rate=config.hidden_dropout_prob,
+            dtype=dtype,
         )

-        super().__init__(config, model, state, seed)
+        super().__init__(config, module, state, seed)
+
+    @property
+    def module(self) -> nn.Module:
+        return self._module

     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
+    def __call__(
+        self,
+        input_ids,
+        token_type_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+        train: bool = False,
+    ):
+        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
+            input_ids, attention_mask, token_type_ids, position_ids
+        )
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        return self.module.apply(
+            {"params": params or self.params},
+            jnp.array(input_ids, dtype="i4"),
+            jnp.array(attention_mask, dtype="i4"),
+            jnp.array(token_type_ids, dtype="i4"),
+            jnp.array(position_ids, dtype="i4"),
+            not train,
+            rngs=rngs,
+        )
+
+    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
+        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
+            jnp.zeros(input_shape, dtype="i4"), None, None, None
+        )
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
+    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
         if token_type_ids is None:
             token_type_ids = jnp.ones_like(input_ids)

         if position_ids is None:
-            position_ids = jnp.arange(
-                self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
-            )
+            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)

         if attention_mask is None:
             attention_mask = jnp.ones_like(input_ids)

-        return self.model.apply(
-            {"params": self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-        )
+        return input_ids, attention_mask, token_type_ids, position_ids
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: jnp.ndarray
+        padding_idx: int
+
+    Returns: jnp.ndarray
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = (input_ids != padding_idx).astype("i4")
+    incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
+    return incremental_indices.astype("i4") + padding_idx
@@ -14,6 +14,15 @@ class FlaxAutoModel:
         requires_flax(self)


+class FlaxBertForMaskedLM:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
 class FlaxBertModel:
     def __init__(self, *args, **kwargs):
         requires_flax(self)
...
@@ -57,7 +57,7 @@ class FlaxRobertaModelTest(unittest.TestCase):
         self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
         for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
-            self.assert_almost_equals(fx_output, pt_output.numpy(), 6e-4)
+            self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)

     def test_multiple_sequences(self):
         tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
...