Unverified commit 0b98ca36, authored by Patrick von Platen, committed by GitHub

[Flax] Adapt Flax models to new structure (#9484)



* Create modeling_flax_electra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Fix code quality

* Adjust tol in assert_almost_equal due to very small differences between model outputs, ranging from 0.0010 to 0.0016

* Remove redundant ElectraPooler

* save intermediate

* adapt

* correct bert flax design

* adapt roberta as well

* finish roberta flax

* finish

* apply suggestions

* apply suggestions
Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
parent 5c0bf397
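
The diff below applies one pattern throughout: every Flax module stops declaring its layers inline under @nn.compact and instead declares them once in setup(), driven by a single config object. A minimal, self-contained sketch of that pattern (illustrative only; ToyConfig, ToyDense and the sizes below are made-up names, not code from this commit):

# Illustrative toy module in the new style (config object + setup()); not from the diff.
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn


@dataclass
class ToyConfig:
    hidden_size: int = 8
    initializer_range: float = 0.02


class ToyDense(nn.Module):
    config: ToyConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # submodule declared once, from the config, instead of inline under @nn.compact
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        return self.dense(hidden_states)


module = ToyDense(ToyConfig())
params = module.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 8)))
out = module.apply(params, jnp.ones((1, 4, 8)))

Note that the diff keeps passing explicit name=... arguments when building submodules, so the parameter-tree keys stay identical to the previous @nn.compact layout and existing checkpoints keep loading.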
@@ -97,6 +97,7 @@ class FlaxBertLayerNorm(nn.Module):
     Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
     """
 
+    hidden_size: int
     epsilon: float = 1e-6
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     bias: bool = True  # If True, bias (beta) is added.
@@ -106,7 +107,10 @@ class FlaxBertLayerNorm(nn.Module):
     scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
     bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
 
-    @nn.compact
+    def setup(self):
+        self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
+        self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
+
     def __call__(self, x):
         """
         Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -119,18 +123,17 @@ class FlaxBertLayerNorm(nn.Module):
         Returns:
           Normalized inputs (the same shape as inputs).
         """
-        features = x.shape[-1]
         mean = jnp.mean(x, axis=-1, keepdims=True)
         mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
         var = mean2 - jax.lax.square(mean)
         mul = jax.lax.rsqrt(var + self.epsilon)
         if self.scale:
-            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
+            mul = mul * jnp.asarray(self.gamma)
         y = (x - mean) * mul
         if self.bias:
-            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
+            y = y + jnp.asarray(self.beta)
         return y
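
Why the layer norm now needs an explicit hidden_size field: under @nn.compact the parameter shapes could be read off the input (features = x.shape[-1]) at call time, but setup() runs before the module ever sees an input, so the feature dimension has to be supplied up front. A stand-alone toy sketch of that constraint (ToyLayerNorm is an illustrative name, not code from the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyLayerNorm(nn.Module):
    hidden_size: int  # setup() never sees an input, so the shape must come from a field
    epsilon: float = 1e-6

    def setup(self):
        self.gamma = self.param("gamma", jax.nn.initializers.ones, (self.hidden_size,))
        self.beta = self.param("beta", jax.nn.initializers.zeros, (self.hidden_size,))

    def __call__(self, x):
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - jax.lax.square(mean)
        return (x - mean) * jax.lax.rsqrt(var + self.epsilon) * self.gamma + self.beta


params = ToyLayerNorm(hidden_size=4).init(jax.random.PRNGKey(0), jnp.ones((2, 4)))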
@@ -142,278 +145,232 @@ class FlaxBertEmbedding(nn.Module):
     vocab_size: int
     hidden_size: int
-    kernel_init_scale: float = 0.2
-    emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
+    initializer_range: float
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, inputs):
-        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
-        return jnp.take(embedding, inputs, axis=0)
+    def setup(self):
+        init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
+        self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
+
+    def __call__(self, input_ids):
+        return jnp.take(self.embeddings, input_ids, axis=0)


 class FlaxBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
 
-    vocab_size: int
-    hidden_size: int
-    type_vocab_size: int
-    max_length: int
-    kernel_init_scale: float = 0.2
-    dropout_rate: float = 0.0
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-
-        # Embed
-        w_emb = FlaxBertEmbedding(
-            self.vocab_size,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+    def setup(self):
+        self.word_embeddings = FlaxBertEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="word_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(input_ids.astype("i4")))
-        p_emb = FlaxBertEmbedding(
-            self.max_length,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+        )
+        self.position_embeddings = FlaxBertEmbedding(
+            self.config.max_position_embeddings,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="position_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(position_ids.astype("i4")))
-        t_emb = FlaxBertEmbedding(
-            self.type_vocab_size,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+        )
+        self.token_type_embeddings = FlaxBertEmbedding(
+            self.config.type_vocab_size,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="token_type_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(token_type_ids.astype("i4")))
+        )
+        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
+        # Embed
+        inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
+        position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
+        token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
 
         # Sum all embeddings
-        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
+        hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
 
         # Layer Norm
-        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
-        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
-        return embeddings
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        return hidden_states


 class FlaxBertAttention(nn.Module):
-    num_heads: int
-    head_size: int
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+    def setup(self):
+        self.self_attention = nn.attention.SelfAttention(
+            num_heads=self.config.num_attention_heads,
+            qkv_features=self.config.hidden_size,
+            dropout_rate=self.config.attention_probs_dropout_prob,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            bias_init=jax.nn.initializers.zeros,
+            name="self",
+            dtype=self.dtype,
+        )
+        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
+
+    def __call__(self, hidden_states, attention_mask, deterministic=True):
         # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
         # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
         # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
         attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
-        self_att = nn.attention.SelfAttention(
-            num_heads=self.num_heads,
-            qkv_features=self.head_size,
-            dropout_rate=self.dropout_rate,
-            deterministic=deterministic,
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
-            bias_init=jax.nn.initializers.zeros,
-            name="self",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask)
+        self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
 
-        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
-        return layer_norm
+        hidden_states = self.layer_norm(self_attn_output + hidden_states)
+        return hidden_states


 class FlaxBertIntermediate(nn.Module):
-    output_size: int
-    hidden_act: str = "gelu"
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states):
-        hidden_states = nn.Dense(
-            features=self.output_size,
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.intermediate_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(hidden_states)
-        hidden_states = ACT2FN[self.hidden_act](hidden_states)
+        )
+        self.activation = ACT2FN[self.config.hidden_act]
+
+    def __call__(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.activation(hidden_states)
         return hidden_states


 class FlaxBertOutput(nn.Module):
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
-        hidden_states = nn.Dense(
-            attention_output.shape[-1],
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(intermediate_output)
-        hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
-        hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
+
+    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        hidden_states = self.layer_norm(hidden_states + attention_output)
         return hidden_states


 class FlaxBertLayer(nn.Module):
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
-        attention = FlaxBertAttention(
-            self.num_heads,
-            self.head_size,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="attention",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask, deterministic=deterministic)
-        intermediate = FlaxBertIntermediate(
-            self.intermediate_size,
-            kernel_init_scale=self.kernel_init_scale,
-            hidden_act=self.hidden_act,
-            name="intermediate",
-            dtype=self.dtype,
-        )(attention)
-        output = FlaxBertOutput(
-            kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
-        )(intermediate, attention, deterministic=deterministic)
-        return output
+    def setup(self):
+        self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
+        self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
+        self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
+
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+        attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
+        hidden_states = self.intermediate(attention_output)
+        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
+        return hidden_states


 class FlaxBertLayerCollection(nn.Module):
-    """
-    Stores N BertLayer(s)
-    """
-
-    num_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, inputs, attention_mask, deterministic: bool = True):
-        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
-
-        # Initialize input / output
-        input_i = inputs
-
-        # Forward over all encoders
-        for i in range(self.num_layers):
-            layer = FlaxBertLayer(
-                self.num_heads,
-                self.head_size,
-                self.intermediate_size,
-                kernel_init_scale=self.kernel_init_scale,
-                dropout_rate=self.dropout_rate,
-                hidden_act=self.hidden_act,
-                name=f"{i}",
-                dtype=self.dtype,
-            )
-            input_i = layer(input_i, attention_mask, deterministic=deterministic)
-        return input_i
+    def setup(self):
+        self.layers = [
+            FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
+        return hidden_states


 class FlaxBertEncoder(nn.Module):
-    num_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
+    def setup(self):
+        self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
+
     def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
-        layer = FlaxBertLayerCollection(
-            self.num_layers,
-            self.num_heads,
-            self.head_size,
-            self.intermediate_size,
-            hidden_act=self.hidden_act,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="layer",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask, deterministic=deterministic)
-        return layer
+        return self.layers(hidden_states, attention_mask, deterministic=deterministic)


 class FlaxBertPooler(nn.Module):
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states):
-        cls_token = hidden_states[:, 0]
-        out = nn.Dense(
-            hidden_states.shape[-1],
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(cls_token)
-        return nn.tanh(out)
+        )
+
+    def __call__(self, hidden_states):
+        cls_hidden_state = hidden_states[:, 0]
+        cls_hidden_state = self.dense(cls_hidden_state)
+        return nn.tanh(cls_hidden_state)


 class FlaxBertPredictionHeadTransform(nn.Module):
-    hidden_act: str = "gelu"
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32
 
-    @nn.compact
+    def setup(self):
+        self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
+        self.activation = ACT2FN[self.config.hidden_act]
+        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
+
     def __call__(self, hidden_states):
-        hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
-        hidden_states = ACT2FN[self.hidden_act](hidden_states)
-        return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return self.layer_norm(hidden_states)


 class FlaxBertLMPredictionHead(nn.Module):
-    vocab_size: int
-    hidden_act: str = "gelu"
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32
 
-    @nn.compact
+    def setup(self):
+        self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
+        self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
+
     def __call__(self, hidden_states):
         # TODO: The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         # Need a link between the two variables so that the bias is correctly
         # resized with `resize_token_embeddings`
-        hidden_states = FlaxBertPredictionHeadTransform(
-            name="transform", hidden_act=self.hidden_act, dtype=self.dtype
-        )(hidden_states)
-        hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
         return hidden_states


 class FlaxBertOnlyMLMHead(nn.Module):
-    vocab_size: int
-    hidden_act: str = "gelu"
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32
 
-    @nn.compact
+    def setup(self):
+        self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
+
     def __call__(self, hidden_states):
-        hidden_states = FlaxBertLMPredictionHead(
-            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
-        )(hidden_states)
+        hidden_states = self.mlm_head(hidden_states)
         return hidden_states
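
The layer stack above is now a plain Python list built in setup(); Flax registers list elements as submodules, and the explicit name=str(i) keeps the parameter-tree keys ("0", "1", ...) identical to the old loop that created one named layer per iteration. A small stand-alone sketch of that behaviour (Block and Stack are toy names, sizes assumed):

import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, name="dense")(x)


class Stack(nn.Module):
    num_layers: int
    features: int

    def setup(self):
        # explicit name=str(i) keeps parameter keys "0", "1", ... stable
        self.layers = [Block(self.features, name=str(i)) for i in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


params = Stack(num_layers=2, features=4).init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
print(list(params["params"].keys()))  # ["0", "1"]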
@@ -543,20 +500,7 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
     def __init__(
         self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
     ):
-        module = FlaxBertModule(
-            vocab_size=config.vocab_size,
-            hidden_size=config.hidden_size,
-            type_vocab_size=config.type_vocab_size,
-            max_length=config.max_position_embeddings,
-            num_encoder_layers=config.num_hidden_layers,
-            num_heads=config.num_attention_heads,
-            head_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            dropout_rate=config.hidden_dropout_prob,
-            hidden_act=config.hidden_act,
-            dtype=dtype,
-            **kwargs,
-        )
+        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
 
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@@ -592,71 +536,34 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
 class FlaxBertModule(nn.Module):
-    vocab_size: int
-    hidden_size: int
-    type_vocab_size: int
-    max_length: int
-    num_encoder_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     add_pooling_layer: bool = True
 
-    @nn.compact
+    def setup(self):
+        self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
+        self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
+        self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
+
     def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
-        # Embedding
-        embeddings = FlaxBertEmbeddings(
-            self.vocab_size,
-            self.hidden_size,
-            self.type_vocab_size,
-            self.max_length,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="embeddings",
-            dtype=self.dtype,
-        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
-
-        # N stacked encoding layers
-        encoder = FlaxBertEncoder(
-            self.num_encoder_layers,
-            self.num_heads,
-            self.head_size,
-            self.intermediate_size,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            hidden_act=self.hidden_act,
-            name="encoder",
-            dtype=self.dtype,
-        )(embeddings, attention_mask, deterministic=deterministic)
+        hidden_states = self.embeddings(
+            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
+        )
+        hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
 
         if not self.add_pooling_layer:
-            return encoder
+            return hidden_states
 
-        pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
-        return encoder, pooled
+        pooled = self.pooler(hidden_states)
+        return hidden_states, pooled


 class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
     def __init__(
         self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
     ):
-        module = FlaxBertForMaskedLMModule(
-            vocab_size=config.vocab_size,
-            type_vocab_size=config.type_vocab_size,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            head_size=config.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_encoder_layers=config.num_hidden_layers,
-            max_length=config.max_position_embeddings,
-            hidden_act=config.hidden_act,
-            **kwargs,
-        )
+        module = FlaxBertForMaskedLMModule(config, **kwargs)
 
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@@ -691,43 +598,32 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
 class FlaxBertForMaskedLMModule(nn.Module):
-    vocab_size: int
-    hidden_size: int
-    intermediate_size: int
-    head_size: int
-    num_heads: int
-    num_encoder_layers: int
-    type_vocab_size: int
-    max_length: int
-    hidden_act: str
-    dropout_rate: float = 0.0
+    config: BertConfig
     dtype: jnp.dtype = jnp.float32
 
-    @nn.compact
+    def setup(self):
+        self.encoder = FlaxBertModule(
+            config=self.config,
+            add_pooling_layer=False,
+            name="bert",
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        self.mlm_head = FlaxBertOnlyMLMHead(
+            config=self.config,
+            name="cls",
+            dtype=self.dtype,
+        )
+
     def __call__(
         self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
     ):
         # Model
-        encoder = FlaxBertModule(
-            vocab_size=self.vocab_size,
-            type_vocab_size=self.type_vocab_size,
-            hidden_size=self.hidden_size,
-            intermediate_size=self.intermediate_size,
-            head_size=self.hidden_size,
-            num_heads=self.num_heads,
-            num_encoder_layers=self.num_encoder_layers,
-            max_length=self.max_length,
-            dropout_rate=self.dropout_rate,
-            hidden_act=self.hidden_act,
-            dtype=self.dtype,
-            add_pooling_layer=False,
-            name="bert",
-        )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
+        hidden_states = self.encoder(
+            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
+        )
 
         # Compute the prediction scores
-        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
-        logits = FlaxBertOnlyMLMHead(
-            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
-        )(encoder)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        logits = self.mlm_head(hidden_states)
 
         return (logits,)
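
With the wrappers above reduced to module = FlaxBertModule(config=config, dtype=dtype, **kwargs), the public API is unchanged. A hedged usage sketch (the checkpoint name, example sentence, and tuple-style output are illustrative of this point in the refactor, not taken from the commit):

from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, Flax!", return_tensors="jax")
outputs = model(**inputs)
# FlaxBertModule returns (sequence_output, pooled_output) when add_pooling_layer=True
print(outputs[0].shape)  # (batch_size, sequence_length, hidden_size)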
@@ -114,6 +114,7 @@ class FlaxRobertaLayerNorm(nn.Module):
     Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
     """
 
+    hidden_size: int
     epsilon: float = 1e-6
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     bias: bool = True  # If True, bias (beta) is added.
@@ -123,7 +124,10 @@ class FlaxRobertaLayerNorm(nn.Module):
     scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
     bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
 
-    @nn.compact
+    def setup(self):
+        self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
+        self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
+
     def __call__(self, x):
         """
         Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -136,18 +140,17 @@ class FlaxRobertaLayerNorm(nn.Module):
         Returns:
           Normalized inputs (the same shape as inputs).
         """
-        features = x.shape[-1]
         mean = jnp.mean(x, axis=-1, keepdims=True)
         mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
         var = mean2 - jax.lax.square(mean)
         mul = jax.lax.rsqrt(var + self.epsilon)
         if self.scale:
-            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
+            mul = mul * jnp.asarray(self.gamma)
         y = (x - mean) * mul
         if self.bias:
-            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
+            y = y + jnp.asarray(self.beta)
         return y
@@ -160,243 +163,202 @@ class FlaxRobertaEmbedding(nn.Module):
     vocab_size: int
     hidden_size: int
-    kernel_init_scale: float = 0.2
-    emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
+    initializer_range: float
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, inputs):
-        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
-        return jnp.take(embedding, inputs, axis=0)
+    def setup(self):
+        init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
+        self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
+
+    def __call__(self, input_ids):
+        return jnp.take(self.embeddings, input_ids, axis=0)


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
 class FlaxRobertaEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
 
-    vocab_size: int
-    hidden_size: int
-    type_vocab_size: int
-    max_length: int
-    kernel_init_scale: float = 0.2
-    dropout_rate: float = 0.0
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-
-        # Embed
-        w_emb = FlaxRobertaEmbedding(
-            self.vocab_size,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+    def setup(self):
+        self.word_embeddings = FlaxRobertaEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="word_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(input_ids.astype("i4")))
-        p_emb = FlaxRobertaEmbedding(
-            self.max_length,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+        )
+        self.position_embeddings = FlaxRobertaEmbedding(
+            self.config.max_position_embeddings,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="position_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(position_ids.astype("i4")))
-        t_emb = FlaxRobertaEmbedding(
-            self.type_vocab_size,
-            self.hidden_size,
-            kernel_init_scale=self.kernel_init_scale,
+        )
+        self.token_type_embeddings = FlaxRobertaEmbedding(
+            self.config.type_vocab_size,
+            self.config.hidden_size,
+            initializer_range=self.config.initializer_range,
             name="token_type_embeddings",
             dtype=self.dtype,
-        )(jnp.atleast_2d(token_type_ids.astype("i4")))
+        )
+        self.layer_norm = FlaxRobertaLayerNorm(
+            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
+        # Embed
+        inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
+        position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
+        token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
 
         # Sum all embeddings
-        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
+        hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
 
         # Layer Norm
-        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
-        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
-        return embeddings
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        return hidden_states


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
 class FlaxRobertaAttention(nn.Module):
-    num_heads: int
-    head_size: int
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+    def setup(self):
+        self.self_attention = nn.attention.SelfAttention(
+            num_heads=self.config.num_attention_heads,
+            qkv_features=self.config.hidden_size,
+            dropout_rate=self.config.attention_probs_dropout_prob,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            bias_init=jax.nn.initializers.zeros,
+            name="self",
+            dtype=self.dtype,
+        )
+        self.layer_norm = FlaxRobertaLayerNorm(
+            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
+        )
+
+    def __call__(self, hidden_states, attention_mask, deterministic=True):
         # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
         # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
         # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
         attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
-        self_att = nn.attention.SelfAttention(
-            num_heads=self.num_heads,
-            qkv_features=self.head_size,
-            dropout_rate=self.dropout_rate,
-            deterministic=deterministic,
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
-            bias_init=jax.nn.initializers.zeros,
-            name="self",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask)
+        self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
 
-        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
-        return layer_norm
+        hidden_states = self.layer_norm(self_attn_output + hidden_states)
+        return hidden_states


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
 class FlaxRobertaIntermediate(nn.Module):
-    output_size: int
-    hidden_act: str = "gelu"
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states):
-        hidden_states = nn.Dense(
-            features=self.output_size,
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.intermediate_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(hidden_states)
-        hidden_states = ACT2FN[self.hidden_act](hidden_states)
+        )
+        self.activation = ACT2FN[self.config.hidden_act]
+
+    def __call__(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.activation(hidden_states)
         return hidden_states


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
 class FlaxRobertaOutput(nn.Module):
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
-        hidden_states = nn.Dense(
-            attention_output.shape[-1],
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(intermediate_output)
-        hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
-        hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        self.layer_norm = FlaxRobertaLayerNorm(
+            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
+        )
+
+    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        hidden_states = self.layer_norm(hidden_states + attention_output)
         return hidden_states


+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
 class FlaxRobertaLayer(nn.Module):
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
-        attention = FlaxRobertaAttention(
-            self.num_heads,
-            self.head_size,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="attention",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask, deterministic=deterministic)
-        intermediate = FlaxRobertaIntermediate(
-            self.intermediate_size,
-            kernel_init_scale=self.kernel_init_scale,
-            hidden_act=self.hidden_act,
-            name="intermediate",
-            dtype=self.dtype,
-        )(attention)
-        output = FlaxRobertaOutput(
-            kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
-        )(intermediate, attention, deterministic=deterministic)
-        return output
+    def setup(self):
+        self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
+        self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
+        self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
+
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+        attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
+        hidden_states = self.intermediate(attention_output)
+        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
+        return hidden_states


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
 class FlaxRobertaLayerCollection(nn.Module):
-    """
-    Stores N RobertaLayer(s)
-    """
-
-    num_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, inputs, attention_mask, deterministic: bool = True):
-        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
-
-        # Initialize input / output
-        input_i = inputs
-
-        # Forward over all encoders
-        for i in range(self.num_layers):
-            layer = FlaxRobertaLayer(
-                self.num_heads,
-                self.head_size,
-                self.intermediate_size,
-                kernel_init_scale=self.kernel_init_scale,
-                dropout_rate=self.dropout_rate,
-                hidden_act=self.hidden_act,
-                name=f"{i}",
-                dtype=self.dtype,
-            )
-            input_i = layer(input_i, attention_mask, deterministic=deterministic)
-        return input_i
+    def setup(self):
+        self.layers = [
+            FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
+        return hidden_states


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
 class FlaxRobertaEncoder(nn.Module):
-    num_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
+    def setup(self):
+        self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
+
     def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
-        layer = FlaxRobertaLayerCollection(
-            self.num_layers,
-            self.num_heads,
-            self.head_size,
-            self.intermediate_size,
-            hidden_act=self.hidden_act,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="layer",
-            dtype=self.dtype,
-        )(hidden_states, attention_mask, deterministic=deterministic)
-        return layer
+        return self.layers(hidden_states, attention_mask, deterministic=deterministic)


 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
 class FlaxRobertaPooler(nn.Module):
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
-    @nn.compact
-    def __call__(self, hidden_states):
-        cls_token = hidden_states[:, 0]
-        out = nn.Dense(
-            hidden_states.shape[-1],
-            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(cls_token)
-        return nn.tanh(out)
+        )
+
+    def __call__(self, hidden_states):
+        cls_hidden_state = hidden_states[:, 0]
+        cls_hidden_state = self.dense(cls_hidden_state)
+        return nn.tanh(cls_hidden_state)


 class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
@@ -520,21 +482,7 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
         dtype: jnp.dtype = jnp.float32,
         **kwargs
     ):
-        module = FlaxRobertaModule(
-            vocab_size=config.vocab_size,
-            hidden_size=config.hidden_size,
-            type_vocab_size=config.type_vocab_size,
-            max_length=config.max_position_embeddings,
-            num_encoder_layers=config.num_hidden_layers,
-            num_heads=config.num_attention_heads,
-            head_size=config.hidden_size,
-            hidden_act=config.hidden_act,
-            intermediate_size=config.intermediate_size,
-            dropout_rate=config.hidden_dropout_prob,
-            dtype=dtype,
-            **kwargs,
-        )
+        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
 
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -570,50 +518,24 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
 class FlaxRobertaModule(nn.Module):
-    vocab_size: int
-    hidden_size: int
-    type_vocab_size: int
-    max_length: int
-    num_encoder_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    hidden_act: str = "gelu"
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
+    config: RobertaConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     add_pooling_layer: bool = True
 
-    @nn.compact
+    def setup(self):
+        self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
+        self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
+        self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
+
     def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
-        # Embedding
-        embeddings = FlaxRobertaEmbeddings(
-            self.vocab_size,
-            self.hidden_size,
-            self.type_vocab_size,
-            self.max_length,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="embeddings",
-            dtype=self.dtype,
-        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
-
-        # N stacked encoding layers
-        encoder = FlaxRobertaEncoder(
-            self.num_encoder_layers,
-            self.num_heads,
-            self.head_size,
-            self.intermediate_size,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            hidden_act=self.hidden_act,
-            name="encoder",
-            dtype=self.dtype,
-        )(embeddings, attention_mask, deterministic=deterministic)
+        hidden_states = self.embeddings(
+            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
+        )
+        hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
 
         if not self.add_pooling_layer:
-            return encoder
+            return hidden_states
 
-        pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
-        return encoder, pooled
+        pooled = self.pooler(hidden_states)
+        return hidden_states, pooled
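
The RoBERTa side follows the same config-driven pattern, so the module can now be instantiated and initialized from a config object alone. A stand-alone sketch (the tiny config values, shapes, and the direct import of FlaxRobertaModule are illustrative assumptions, not code from the commit):

import jax
import jax.numpy as jnp
from transformers import RobertaConfig
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModule

config = RobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64
)
module = FlaxRobertaModule(config=config)

input_ids = jnp.ones((1, 8), dtype="i4")
attention_mask = jnp.ones_like(input_ids)
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.arange(8)[None, :]

# deterministic defaults to True, so no dropout RNG is needed at init time
params = module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_type_ids, position_ids)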
@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
     return attn_mask
 
 
+@require_flax
 class FlaxModelTesterMixin:
     model_tester = None
     all_model_classes = ()
@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
                 fx_outputs = fx_model(**inputs_dict)
                 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                 for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
+                    self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     pt_model.save_pretrained(tmpdirname)
@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
                 for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                     self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
 
-    @require_flax
     def test_from_pretrained_save_pretrained(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
             for output_loaded, output in zip(outputs_loaded, outputs):
                 self.assert_almost_equals(output_loaded, output, 5e-3)
 
-    @require_flax
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
             for jitted_output, output in zip(jitted_outputs, outputs):
                 self.assertEqual(jitted_output.shape, output.shape)
 
-    @require_flax
     def test_naming_convention(self):
         for model_class in self.all_model_classes:
             model_class_name = model_class.__name__
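
For context on the 1e-3 to 2e-3 change above: the helper compares the largest absolute elementwise difference between the Flax and PyTorch outputs against the tolerance, and the commit message reports observed gaps of 0.0010-0.0016, which sit just above the old bound. A minimal stand-alone equivalent of that check (illustrative, not the test suite's exact code):

import numpy as np


def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float):
    # fail when the largest absolute elementwise difference exceeds the tolerance
    diff = np.abs(a - b).max()
    assert diff <= tol, f"max difference {diff:.4f} exceeds tolerance {tol}"


assert_almost_equals(np.array([1.0000]), np.array([1.0016]), 2e-3)  # passes; would fail at 1e-3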