Unverified Commit 75627148 authored by Funtowicz Morgan, committed by GitHub

Flax Masked Language Modeling training example (#8728)



* Remove "Model" suffix from Flax models to look more :hugs:
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Initial working (forward + backward) for Flax MLM training example.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Simply code
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Addressing comments, using module and moving to LM task.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Restore parameter name "module" wrongly renamed model.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Restore correct output ordering...
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Actually commit the example 😅

Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Add FlaxBertModelForMaskedLM after rebasing.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Make it possible to initialize the training from scratch
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Reuse flax linen example of cross entropy loss
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Added specific data collator for flax
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Remove todo for data collator
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Added evaluation step
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Added ability to provide dtype to support bfloat16 on TPU
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Enable flax tensorboard output
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Enable jax.pmap support.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Ensure batches are correctly sized to be dispatched with jax.pmap
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Enable bfloat16 with --fp16 cmdline args
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Correctly export metrics to tensorboard
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Added dropout and ability to use it.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Effectively enable & disable during training and evaluation steps.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Oops.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Enable specifying kernel initializer scale
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Style.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Added warmup step to the learning rate scheduler.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fix typo.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Print training loss
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Make style
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* fix linter issue (flake8)
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fix model matching
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fix dummies
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fix non default dtype on Flax models
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Use the same create_position_ids_from_input_ids for FlaxRoberta
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta attention as Bert
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* fix copy
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Wording.
Co-authored-by: default avatarMarc van Zee <marcvanzee@gmail.com>
Co-authored-by: default avatarMarc van Zee <marcvanzee@gmail.com>
parent df2af6d8
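Several of the commits above (the flax linen cross-entropy loss, jax.pmap support, and the batch-sizing fix) rely on JAX's standard data-parallel pattern. The sketch below is illustrative only and is not code from this PR; `loss_fn`, `params`, and `batch` are stand-ins.

```python
# Illustrative data-parallel train step (not taken from the PR); loss_fn/params/batch are stand-ins.
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Stand-in loss; the real example computes a masked-LM cross entropy over the Flax module's logits.
    logits = batch["inputs"] @ params["w"]
    return jnp.mean((logits - batch["targets"]) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    # Each device receives one shard of the batch, so the global batch size must be
    # divisible by jax.device_count() (hence the batch-sizing commit above).
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average gradients across devices
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)  # plain SGD for brevity
    return params, jax.lax.pmean(loss, axis_name="batch")
```

In this pattern the parameters are replicated across devices, and each batch is reshaped to a leading `(num_devices, per_device_batch, ...)` axis before being passed to `train_step`.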
@@ -936,7 +936,7 @@ else:
if is_flax_available():
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import FlaxBertModel
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
from .models.roberta import FlaxRobertaModel
else:
# Import the same objects as dummies to get them in the namespace.
......
@@ -65,13 +65,12 @@ class FlaxPreTrainedModel(ABC):
base_model_prefix = ""
model_class = None
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
def __init__(
self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
):
if config is None:
raise ValueError("config cannot be None")
if module is None:
raise ValueError("module cannot be None")
if params is None:
raise ValueError("state cannot be None")
@@ -82,19 +81,23 @@ class FlaxPreTrainedModel(ABC):
# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(seed)
self.params = params
self.model = module
self.dtype = dtype
@property
def config(self) -> PretrainedConfig:
return self._config
@property
def module(self) -> nn.Module:
return self._module
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
r"""
Instantiate a pretrained Flax model from a pre-trained model configuration.
"""
@@ -127,6 +130,9 @@ class FlaxPreTrainedModel(ABC):
else:
model_kwargs = kwargs
# Add the dtype to model_kwargs
model_kwargs["dtype"] = dtype
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
......
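The hunk above threads a new dtype argument through from_pretrained and into the model kwargs. A minimal usage sketch (the checkpoint name is only an example):

```python
# Minimal sketch of the new dtype argument; "bert-base-cased" is just an example checkpoint.
import jax.numpy as jnp

from transformers import FlaxBertForMaskedLM

# dtype controls the dtype of the computation (see the module definitions below);
# on TPU, jnp.bfloat16 is the usual choice, which the example script exposes via --fp16.
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.bfloat16)
```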
@@ -59,4 +59,4 @@ if is_tf_available():
)
if is_flax_available():
from .modeling_flax_bert import FlaxBertModel
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
from typing import Callable, Dict, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
@@ -101,8 +102,8 @@ class FlaxRobertaLayerNorm(nn.Module):
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
@nn.compact
def __call__(self, x):
@@ -122,11 +123,13 @@ class FlaxRobertaLayerNorm(nn.Module):
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
return y
@@ -139,7 +142,9 @@ class FlaxRobertaEmbedding(nn.Module):
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
kernel_init_scale: float = 0.2
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs):
@@ -155,66 +160,108 @@ class FlaxRobertaEmbeddings(nn.Module):
hidden_size: int
type_vocab_size: int
max_length: int
kernel_init_scale: float = 0.2
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
jnp.atleast_2d(input_ids.astype("i4"))
)
p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
jnp.atleast_2d(position_ids.astype("i4"))
)
t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
jnp.atleast_2d(token_type_ids.astype("i4"))
)
w_emb = FlaxRobertaEmbedding(
self.vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="word_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(input_ids.astype("i4")))
p_emb = FlaxRobertaEmbedding(
self.max_length,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="position_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(position_ids.astype("i4")))
t_emb = FlaxRobertaEmbedding(
self.type_vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="token_type_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(token_type_ids.astype("i4")))
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
# Layer Norm
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
return layer_norm
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
return embeddings
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
self_att = nn.attention.SelfAttention(
num_heads=self.num_heads,
qkv_features=self.head_size,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_state, attention_mask)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
return layer_norm
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
dense = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_state)
return gelu(dense)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_state = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
return hidden_state
@@ -222,14 +269,29 @@ class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
hidden_state, attention_mask
)
intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
output = FlaxRobertaOutput(name="output")(intermediate, attention)
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
attention = FlaxRobertaAttention(
self.num_heads,
self.head_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
intermediate = FlaxRobertaIntermediate(
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
name="intermediate",
dtype=self.dtype,
)(attention)
output = FlaxRobertaOutput(
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
)(intermediate, attention, deterministic=deterministic)
return output
@@ -244,9 +306,12 @@ class FlaxRobertaLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs, attention_mask):
def __call__(self, inputs, attention_mask, deterministic: bool = True):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
@@ -254,8 +319,16 @@ class FlaxRobertaLayerCollection(nn.Module):
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
layer = FlaxRobertaLayer(
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name=f"{i}",
dtype=self.dtype,
)
input_i = layer(input_i, attention_mask, deterministic=deterministic)
return input_i
@@ -265,22 +338,40 @@ class FlaxRobertaEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
layer = FlaxRobertaLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
)(hidden_state, attention_mask)
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
return layer
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
return jax.lax.tanh(out)
out = nn.Dense(
hidden_state.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(cls_token)
return nn.tanh(out)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
@@ -293,21 +384,38 @@ class FlaxRobertaModule(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
)(input_ids, token_type_ids, position_ids, attention_mask)
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
)(embeddings, attention_mask)
pooled = FlaxRobertaPooler(name="pooler")(encoder)
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
@@ -396,8 +504,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxRobertaModule(
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
module = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
@@ -406,31 +514,78 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
super().__init__(config, module, state, seed)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
def __call__(
self,
input_ids,
token_type_ids=None,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(
self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
)
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
)
return input_ids, attention_mask, token_type_ids, position_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: jnp.ndarray
padding_idx: int
Returns: jnp.ndarray
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = (input_ids != padding_idx).astype("i4")
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
return incremental_indices.astype("i4") + padding_idx
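A quick sanity check of the position-id logic above, assuming RoBERTa's usual pad_token_id of 1 (the token ids are just example values):

```python
import jax.numpy as jnp

input_ids = jnp.array([[0, 31414, 232, 2, 1, 1]])   # last two tokens are padding (pad_token_id == 1)
mask = (input_ids != 1).astype("i4")                 # [[1, 1, 1, 1, 0, 0]]
position_ids = jnp.cumsum(mask, axis=1) * mask + 1   # [[2, 3, 4, 5, 1, 1]]
# Non-padding tokens get positions starting at pad_token_id + 1; padding positions stay at pad_token_id.
```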
@@ -14,6 +14,15 @@ class FlaxAutoModel:
requires_flax(self)
class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
......
@@ -57,7 +57,7 @@ class FlaxRobertaModelTest(unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
self.assert_almost_equals(fx_output, pt_output.numpy(), 6e-4)
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
def test_multiple_sequences(self):
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
......