"tests/models/gptj/test_modeling_gptj.py" did not exist on "a573777901e662ec2e565be312ffaeedef6effec"
Unverified Commit 21aecc09 authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add Flax RoFormer (#15005)



* Add FlaxRoFormer

* Clean code + make quality

* Fix output pooling for FlaxRoFormerForMultipleChoiceModule

* Apply suggestions from code review

* add flax model to repos
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 9e1775dd
......@@ -246,7 +246,7 @@ Flax), PyTorch, and/or TensorFlow.
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | |
| RoFormer | ✅ | ✅ | ✅ | ✅ | |
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -123,3 +123,33 @@ This model was contributed by [junnyu](https://huggingface.co/junnyu). The origi
[[autodoc]] TFRoFormerForQuestionAnswering
- call
## FlaxRoFormerModel
[[autodoc]] FlaxRoFormerModel
- __call__
## FlaxRoFormerForMaskedLM
[[autodoc]] FlaxRoFormerForMaskedLM
- __call__
## FlaxRoFormerForSequenceClassification
[[autodoc]] FlaxRoFormerForSequenceClassification
- __call__
## FlaxRoFormerForMultipleChoice
[[autodoc]] FlaxRoFormerForMultipleChoice
- __call__
## FlaxRoFormerForTokenClassification
[[autodoc]] FlaxRoFormerForTokenClassification
- __call__
## FlaxRoFormerForQuestionAnswering
[[autodoc]] FlaxRoFormerForQuestionAnswering
- __call__
......@@ -2084,6 +2084,17 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel",
]
)
_import_structure["models.roformer"].extend(
[
"FlaxRoFormerForMaskedLM",
"FlaxRoFormerForMultipleChoice",
"FlaxRoFormerForQuestionAnswering",
"FlaxRoFormerForSequenceClassification",
"FlaxRoFormerForTokenClassification",
"FlaxRoFormerModel",
"FlaxRoFormerPreTrainedModel",
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
......@@ -3819,6 +3830,15 @@ if TYPE_CHECKING:
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
from .models.roformer import (
FlaxRoFormerForMaskedLM,
FlaxRoFormerForMultipleChoice,
FlaxRoFormerForQuestionAnswering,
FlaxRoFormerForSequenceClassification,
FlaxRoFormerForTokenClassification,
FlaxRoFormerModel,
FlaxRoFormerPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
......
......@@ -50,6 +50,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
("roformer", "FlaxRoFormerModel"),
]
)
......@@ -66,6 +67,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("roformer", "FlaxRoFormerForMaskedLM"),
]
)
......@@ -80,6 +82,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
("roformer", "FlaxRoFormerForMaskedLM"),
]
)
......@@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("bart", "FlaxBartForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"),
]
)
......@@ -146,6 +150,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("bart", "FlaxBartForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"),
]
)
......@@ -158,6 +163,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"),
]
)
......@@ -170,6 +176,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"),
]
)
......
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -59,6 +59,19 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_roformer"] = [
"FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlaxRoFormerForMaskedLM",
"FlaxRoFormerForMultipleChoice",
"FlaxRoFormerForQuestionAnswering",
"FlaxRoFormerForSequenceClassification",
"FlaxRoFormerForTokenClassification",
"FlaxRoFormerModel",
"FlaxRoFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
from .tokenization_roformer import RoFormerTokenizer
......@@ -95,6 +108,18 @@ if TYPE_CHECKING:
TFRoFormerPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_roformer import (
FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaxRoFormerForMaskedLM,
FlaxRoFormerForMultipleChoice,
FlaxRoFormerForQuestionAnswering,
FlaxRoFormerForSequenceClassification,
FlaxRoFormerForTokenClassification,
FlaxRoFormerModel,
FlaxRoFormerPreTrainedModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax RoFormer model."""
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
from ...utils import logging
from .configuration_roformer import RoFormerConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
_CONFIG_FOR_DOC = "RoFormerConfig"
_TOKENIZER_FOR_DOC = "RoFormerTokenizer"
FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"junnyu/roformer_chinese_small",
"junnyu/roformer_chinese_base",
"junnyu/roformer_chinese_char_small",
"junnyu/roformer_chinese_char_base",
"junnyu/roformer_small_discriminator",
"junnyu/roformer_small_generator"
# See all RoFormer models at https://huggingface.co/models?filter=roformer
]
ROFORMER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`RoFormerConfig`]): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
ROFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`RoFormerTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
head_mask (`numpy.ndarray` of shape `({0})`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
def create_sinusoidal_positions(n_pos, dim):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
sentinel = dim // 2 + dim % 2
out = np.zeros_like(position_enc)
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
out[:, sentinel:] = np.cos(position_enc[:, 1::2])
return jnp.array(out)
class FlaxRoFormerEmbeddings(nn.Module):
"""Construct the embeddings from word and token_type embeddings."""
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, attention_mask, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings
# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxRoFormerSelfAttention(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.rotary_value = self.config.rotary_value
def __call__(
self,
hidden_states,
attention_mask,
sinusoidal_pos,
layer_head_mask,
deterministic=True,
output_attentions: bool = False,
):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
if sinusoidal_pos is not None:
if self.rotary_value:
query_states, key_states, value_states = self.apply_rotary_position_embeddings(
sinusoidal_pos, query_states, key_states, value_states
)
else:
query_states, key_states = self.apply_rotary_position_embeddings(
sinusoidal_pos, query_states, key_states
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@staticmethod
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
sin, cos = sinusoidal_pos.split(2, axis=-1)
sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
def rotate_layer(layer, sin_pos, cos_pos):
rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape)
rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos)
rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos)
return rotary_matrix_cos + rotary_matrix_sin
query_layer = rotate_layer(query_layer, sin_pos, cos_pos)
key_layer = rotate_layer(key_layer, sin_pos, cos_pos)
if value_layer is not None:
value_layer = rotate_layer(value_layer, sin_pos, cos_pos)
return query_layer, key_layer, value_layer
return query_layer, key_layer
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->RoFormer
class FlaxRoFormerSelfOutput(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class FlaxRoFormerAttention(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self = FlaxRoFormerSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRoFormerSelfOutput(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
sinusoidal_pos,
layer_head_mask,
deterministic=True,
output_attentions: bool = False,
):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_outputs = self.self(
hidden_states,
attention_mask,
sinusoidal_pos,
layer_head_mask=layer_head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->RoFormer
class FlaxRoFormerIntermediate(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->RoFormer
class FlaxRoFormerOutput(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
class FlaxRoFormerLayer(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxRoFormerAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxRoFormerIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRoFormerOutput(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
sinusiodal_pos,
layer_head_mask,
deterministic: bool = True,
output_attentions: bool = False,
):
attention_outputs = self.attention(
hidden_states,
attention_mask,
sinusiodal_pos,
layer_head_mask=layer_head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
class FlaxRoFormerLayerCollection(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxRoFormerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask,
sinusoidal_pos,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
{head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask,
sinusoidal_pos,
layer_head_mask=head_mask[i] if head_mask is not None else None,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxRoFormerEncoder(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embed_positions = create_sinusoidal_positions(
self.config.max_position_embeddings, self.config.hidden_size // self.config.num_attention_heads
)
self.layer = FlaxRoFormerLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
sinusoidal_pos = self.embed_positions[: hidden_states.shape[1], :]
return self.layer(
hidden_states,
attention_mask,
sinusoidal_pos,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->RoFormer
class FlaxRoFormerPredictionHeadTransform(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return self.LayerNorm(hidden_states)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->RoFormer
class FlaxRoFormerLMPredictionHead(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.transform = FlaxRoFormerPredictionHeadTransform(self.config, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.transform(hidden_states)
if shared_embedding is not None:
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
hidden_states = self.decoder(hidden_states)
bias = jnp.asarray(self.bias, self.dtype)
hidden_states += bias
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->RoFormer
class FlaxRoFormerOnlyMLMHead(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.predictions = FlaxRoFormerLMPredictionHead(self.config, dtype=self.dtype)
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
return hidden_states
class FlaxRoFormerClassificationHead(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.out_proj = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, deterministic=True):
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.dense(hidden_states)
hidden_states = nn.tanh(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = RoFormerConfig
base_model_prefix = "roformer"
module_class: nn.Module = None
def __init__(
self,
config: RoFormerConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
attention_mask = jnp.ones_like(input_ids)
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[
"params"
]
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
head_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if head_mask is None:
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxRoFormerModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embeddings = FlaxRoFormerEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxRoFormerEncoder(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(input_ids, token_type_ids, attention_mask, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if not return_dict:
return (hidden_states,) + outputs[1:]
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerModel(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerModule
append_call_sample_docstring(
FlaxRoFormerModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
)
class FlaxRoFormerForMaskedLMModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
self.cls = FlaxRoFormerOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roformer(
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.roformer.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxMaskedLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
class FlaxRoFormerForMaskedLM(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerForMaskedLMModule
append_call_sample_docstring(
FlaxRoFormerForMaskedLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxMaskedLMOutput,
_CONFIG_FOR_DOC,
mask="<mask>",
)
class FlaxRoFormerForSequenceClassificationModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
self.classifier = FlaxRoFormerClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roformer(
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, deterministic=deterministic)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForSequenceClassification(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerForSequenceClassificationModule
append_call_sample_docstring(
FlaxRoFormerForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxRoFormerForMultipleChoiceModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1])
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1])
# Model
outputs = self.roformer(
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Equivalent to sequence_summary call in the PyTorch implementation
hidden_states = outputs[0]
pooled_output = hidden_states[:, -1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
if not return_dict:
return (reshaped_logits,) + outputs[2:]
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForMultipleChoice(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerForMultipleChoiceModule
overwrite_call_docstring(
FlaxRoFormerForMultipleChoice, ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
FlaxRoFormerForMultipleChoice,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxMultipleChoiceModelOutput,
_CONFIG_FOR_DOC,
)
class FlaxRoFormerForTokenClassificationModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roformer(
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForTokenClassification(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerForTokenClassificationModule
append_call_sample_docstring(
FlaxRoFormerForTokenClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxTokenClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxRoFormerForQuestionAnsweringModule(nn.Module):
config: RoFormerConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roformer(
input_ids,
attention_mask,
token_type_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
return (start_logits, end_logits) + outputs[1:]
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
ROFORMER_START_DOCSTRING,
)
class FlaxRoFormerForQuestionAnswering(FlaxRoFormerPreTrainedModel):
module_class = FlaxRoFormerForQuestionAnsweringModule
append_call_sample_docstring(
FlaxRoFormerForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
......@@ -1316,6 +1316,90 @@ class FlaxRobertaPreTrainedModel:
requires_backends(self, ["flax"])
class FlaxRoFormerForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRoFormerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxT5ForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import RoFormerConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax.numpy as jnp
from transformers.models.roformer.modeling_flax_roformer import (
FlaxRoFormerForMaskedLM,
FlaxRoFormerForMultipleChoice,
FlaxRoFormerForQuestionAnswering,
FlaxRoFormerForSequenceClassification,
FlaxRoFormerForTokenClassification,
FlaxRoFormerModel,
)
class FlaxRoFormerModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
config = RoFormerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxRoFormerModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True
all_model_classes = (
(
FlaxRoFormerModel,
FlaxRoFormerForMaskedLM,
FlaxRoFormerForSequenceClassification,
FlaxRoFormerForTokenClassification,
FlaxRoFormerForMultipleChoice,
FlaxRoFormerForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxRoFormerModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("junnyu/roformer_chinese_small", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@require_flax
class FlaxRoFormerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
input_ids = jnp.array([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
vocab_size = 50000
expected_shape = (1, 6, vocab_size)
self.assertEqual(output.shape, expected_shape)
expected_slice = jnp.array(
[[[-0.1205, -1.0265, 0.2922], [-1.5134, 0.1974, 0.1519], [-5.0135, -3.9003, -0.8404]]]
)
self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment