Unverified Commit 774760e6 authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

distilbert-flax (#13324)

* distilbert-flax

* added missing self

* docs fix

* removed tied kernal extra init

* updated docs

* x -> hidden states

* removed head_mask

* removed from_pt, +FLAX

* updated year
parent 01977466
...@@ -357,7 +357,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -357,7 +357,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT | ✅ | ✅ | ✅ | ✅ | | | DistilBERT | ✅ | ✅ | ✅ | ✅ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ | | DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -44,8 +44,9 @@ Tips: ...@@ -44,8 +44,9 @@ Tips:
- DistilBERT doesn't have options to select the input positions (:obj:`position_ids` input). This could be added if - DistilBERT doesn't have options to select the input positions (:obj:`position_ids` input). This could be added if
necessary though, just let us know if you need this option. necessary though, just let us know if you need this option.
This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The original code can be found This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. This model jax version was
:prefix_link:`here <examples/research-projects/distillation>`. contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found :prefix_link:`here
<examples/research-projects/distillation>`.
DistilBertConfig DistilBertConfig
...@@ -152,3 +153,45 @@ TFDistilBertForQuestionAnswering ...@@ -152,3 +153,45 @@ TFDistilBertForQuestionAnswering
.. autoclass:: transformers.TFDistilBertForQuestionAnswering .. autoclass:: transformers.TFDistilBertForQuestionAnswering
:members: call :members: call
FlaxDistilBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertModel
:members: __call__
FlaxDistilBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertForMaskedLM
:members: __call__
FlaxDistilBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertForSequenceClassification
:members: __call__
FlaxDistilBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertForMultipleChoice
:members: __call__
FlaxDistilBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertForTokenClassification
:members: __call__
FlaxDistilBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxDistilBertForQuestionAnswering
:members: __call__
...@@ -1712,6 +1712,17 @@ if is_flax_available(): ...@@ -1712,6 +1712,17 @@ if is_flax_available():
"FlaxCLIPVisionPreTrainedModel", "FlaxCLIPVisionPreTrainedModel",
] ]
) )
_import_structure["models.distilbert"].extend(
[
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]
)
_import_structure["models.electra"].extend( _import_structure["models.electra"].extend(
[ [
"FlaxElectraForMaskedLM", "FlaxElectraForMaskedLM",
...@@ -3201,6 +3212,15 @@ if TYPE_CHECKING: ...@@ -3201,6 +3212,15 @@ if TYPE_CHECKING:
FlaxCLIPVisionModel, FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel, FlaxCLIPVisionPreTrainedModel,
) )
from .models.distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)
from .models.electra import ( from .models.electra import (
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice, FlaxElectraForMultipleChoice,
......
...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) ...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict( FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[ [
# Base model mapping # Base model mapping
("distilbert", "FlaxDistilBertModel"),
("roberta", "FlaxRobertaModel"), ("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"), ("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"), ("big_bird", "FlaxBigBirdModel"),
...@@ -63,6 +64,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ...@@ -63,6 +64,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Masked LM mapping # Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"), ("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"), ("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"),
...@@ -101,6 +103,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -101,6 +103,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"), ("roberta", "FlaxRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"),
...@@ -113,6 +116,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -113,6 +116,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[ [
# Model for Question Answering mapping # Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"), ("roberta", "FlaxRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"),
...@@ -125,6 +129,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ...@@ -125,6 +129,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Token Classification mapping # Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"), ("roberta", "FlaxRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"), ("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"),
...@@ -135,6 +140,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -135,6 +140,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[ [
# Model for Multiple Choice mapping # Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"), ("roberta", "FlaxRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"),
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -58,6 +58,17 @@ if is_tf_available(): ...@@ -58,6 +58,17 @@ if is_tf_available():
"TFDistilBertPreTrainedModel", "TFDistilBertPreTrainedModel",
] ]
if is_flax_available():
_import_structure["modeling_flax_distilbert"] = [
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_distilbert import ( from .configuration_distilbert import (
...@@ -95,6 +106,17 @@ if TYPE_CHECKING: ...@@ -95,6 +106,17 @@ if TYPE_CHECKING:
TFDistilBertPreTrainedModel, TFDistilBertPreTrainedModel,
) )
if is_flax_available():
from .modeling_flax_distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)
else: else:
import sys import sys
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
from ...utils import logging
from .configuration_distilbert import DistilBertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
FLAX_DISTILBERT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
This model is also a Flax Linen `flax.linen.Module
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
DISTILBERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
def get_angles(pos, i, d_model):
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
return pos * angle_rates
def positional_encoding(position, d_model, dtype):
# create the sinusoidal pattern for the positional encoding
angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model)
# apply sin to even indices in the array; 2i
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
# apply cos to odd indices in the array; 2i+1
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
pos_encoding = angle_rads[np.newaxis, ...]
# cast to dtype
return jnp.array(pos_encoding, dtype=dtype)
class FlaxEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
if not self.config.sinusoidal_pos_embds:
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
else:
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim, self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.dropout)
def __call__(self, input_ids, deterministic: bool = True):
# Embed
batch_size, seq_length = input_ids.shape
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
if not self.config.sinusoidal_pos_embds:
position_ids = jnp.arange(seq_length).astype("i4")
position_ids = jnp.broadcast_to(position_ids, shape=(batch_size, seq_length))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
else:
position_embeds = self.pos_encoding[:, :seq_length, :]
# Sum all embeddings
hidden_states = inputs_embeds + position_embeds
# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxMultiHeadSelfAttention(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.n_heads = self.config.n_heads
self.dim = self.config.dim
self.dropout = nn.Dropout(rate=self.config.attention_dropout)
assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
self.q_lin = nn.Dense(
self.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.k_lin = nn.Dense(
self.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.v_lin = nn.Dense(
self.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.out_lin = nn.Dense(
self.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
self,
query,
key,
value,
mask,
deterministic: bool = True,
output_attentions: bool = False,
):
bs, q_len, dim = query.shape
k_len = key.shape[1]
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
# assert key.size() == value.size()
dim_per_head = self.dim // self.n_heads
mask_reshp = (bs, 1, 1, k_len)
def shape(x):
"""separate heads"""
return x.reshape(bs, -1, self.n_heads, dim_per_head).transpose(0, 2, 1, 3)
def unshape(x):
"""group heads"""
return x.transpose(0, 2, 1, 3).reshape(bs, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_len, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_len, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_len, dim_per_head)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_len, dim_per_head)
scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) # (bs, n_heads, q_len, k_len)
mask = jnp.reshape(mask, mask_reshp)
mask = mask.astype(scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
weights = nn.softmax(scores, axis=-1) # (bs, n_heads, q_len, k_len)
weights = self.dropout(weights, deterministic=deterministic)
context = jnp.matmul(weights, v) # (bs, n_heads, q_len, dim_per_head)
context = unshape(context) # (bs, q_len, dim)
context = self.out_lin(context) # (bs, q_len, dim)
if output_attentions:
return (context, weights)
else:
return (context,)
class FlaxFFN(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dropout = nn.Dropout(rate=self.config.dropout)
self.chunk_size_feed_forward = self.config.chunk_size_feed_forward
self.seq_len_dim = 1
self.lin1 = nn.Dense(
self.config.hidden_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.lin2 = nn.Dense(
self.config.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
assert self.config.activation in [
"relu",
"gelu",
], f"activation ({self.config.activation}) must be in ['relu', 'gelu']"
self.activation = ACT2FN[self.config.activation]
def __call__(self, hidden_states, deterministic: bool = True):
hidden_states = self.lin1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.lin2(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxTransformerBlock(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
assert (
self.config.dim % self.config.n_heads == 0
), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}"
self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype)
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12)
self.ffn = FlaxFFN(self.config, dtype=self.dtype)
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12)
def __call__(
self,
hidden_states,
attn_mask,
output_attentions: bool = False,
deterministic: bool = True,
):
# Self-Attention
sa_output = self.attention(
query=hidden_states,
key=hidden_states,
value=hidden_states,
mask=attn_mask,
output_attentions=output_attentions,
deterministic=deterministic,
)
if output_attentions:
sa_output, sa_weights = sa_output
else:
assert type(sa_output) == tuple
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + hidden_states)
# Feed Forward Network
ffn_output = self.ffn(sa_output, deterministic=deterministic)
ffn_output = self.output_layer_norm(ffn_output + sa_output)
output = (ffn_output,)
if output_attentions:
output = (sa_weights,) + output
return output
class FlaxTransformer(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxTransformerBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.n_layers)
]
def __call__(
self,
hidden_states,
attention_mask,
output_attentions: bool = False,
output_hidden_states: bool = False,
deterministic: bool = True,
return_dict: bool = False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for layer_module in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states=hidden_states,
attn_mask=attention_mask,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[-1]
if output_attentions:
assert len(layer_outputs) == 2
attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,)
else:
assert len(layer_outputs) == 1
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_attentions, all_hidden_states] if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxTransformerEncoder(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layer = FlaxTransformer(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
output_attentions: bool = False,
output_hidden_states: bool = False,
deterministic: bool = True,
return_dict: bool = False,
):
return self.layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
deterministic=deterministic,
return_dict=return_dict,
)
class FlaxDistilBertLMDecoder(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, inputs, kernel):
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
y = y + self.bias
return y
class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DistilBertConfig
base_model_prefix = "distilbert"
module_class: nn.Module = None
def __init__(
self,
config: DistilBertConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
head_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxDistilBertModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embeddings = FlaxEmbeddings(self.config, dtype=self.dtype)
self.transformer = FlaxTransformerEncoder(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
input_embeds = self.embeddings(input_ids, deterministic=deterministic)
return self.transformer(
hidden_states=input_embeds,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@add_start_docstrings(
"The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top.",
FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertModel(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertModule
append_call_sample_docstring(FlaxDistilBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, None, _CONFIG_FOR_DOC)
class FlaxDistilBertForMaskedLMModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype)
self.vocab_transform = nn.Dense(
self.config.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.vocab_projector = FlaxDistilBertLMDecoder(
self.config,
dtype=self.dtype,
)
else:
self.vocab_projector = nn.Dense(
self.config.vocab_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
deterministic=deterministic,
return_dict=return_dict,
)
hidden_states = dlbrt_output[0]
prediction_logits = self.vocab_transform(hidden_states)
prediction_logits = ACT2FN["gelu"](prediction_logits)
prediction_logits = self.vocab_layer_norm(prediction_logits)
if self.config.tie_word_embeddings:
shared_embedding = self.distilbert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
prediction_logits = self.vocab_projector(prediction_logits, shared_embedding.T)
else:
prediction_logits = self.vocab_projector(prediction_logits)
if not return_dict:
output = (prediction_logits,) + dlbrt_output[1:]
return output
return FlaxMaskedLMOutput(
logits=prediction_logits,
hidden_states=dlbrt_output.hidden_states,
attentions=dlbrt_output.attentions,
)
@add_start_docstrings("""DistilBert Model with a `language modeling` head on top. """, FLAX_DISTILBERT_START_DOCSTRING)
class FlaxDistilBertForMaskedLM(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertForMaskedLMModule
append_call_sample_docstring(
FlaxDistilBertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
)
class FlaxDistilBertForSequenceClassificationModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
self.pre_classifier = nn.Dense(
self.config.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout)
self.classifier = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
distilbert_output = self.distilbert(
input_ids,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = ACT2FN["relu"](pooled_output)
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output) # (bs, dim)
if not return_dict:
return (logits,) + distilbert_output[1:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=distilbert_output.hidden_states,
attentions=distilbert_output.attentions,
)
@add_start_docstrings(
"""
DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForSequenceClassification(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertForSequenceClassificationModule
append_call_sample_docstring(
FlaxDistilBertForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxDistilBertForMultipleChoiceModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
self.pre_classifier = nn.Dense(
self.config.dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout)
self.classifier = nn.Dense(
1,
dtype=self.dtype,
)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
# Model
outputs = self.distilbert(
input_ids,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_state = outputs[0]
pooled_output = hidden_state[:, 0]
pooled_output = self.pre_classifier(pooled_output)
pooled_output = ACT2FN["relu"](pooled_output)
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
if not return_dict:
return (reshaped_logits,) + outputs[2:]
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
a softmax) e.g. for RocStories/SWAG tasks.
""",
FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForMultipleChoice(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertForMultipleChoiceModule
overwrite_call_docstring(
FlaxDistilBertForMultipleChoice, DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
FlaxDistilBertForMultipleChoice,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxMultipleChoiceModelOutput,
_CONFIG_FOR_DOC,
)
class FlaxDistilBertForTokenClassificationModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.dropout)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
outputs = self.distilbert(
input_ids,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForTokenClassification(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertForTokenClassificationModule
append_call_sample_docstring(
FlaxDistilBertForTokenClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxTokenClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxDistilBertForQuestionAnsweringModule(nn.Module):
config: DistilBertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
assert self.config.num_labels == 2
self.dropout = nn.Dropout(rate=self.config.qa_dropout)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
distilbert_output = self.distilbert(
input_ids,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = distilbert_output[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
return (start_logits, end_logits) + distilbert_output[1:]
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=distilbert_output.hidden_states,
attentions=distilbert_output.attentions,
)
@add_start_docstrings(
"""
DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForQuestionAnswering(FlaxDistilBertPreTrainedModel):
module_class = FlaxDistilBertForQuestionAnsweringModule
append_call_sample_docstring(
FlaxDistilBertForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
...@@ -448,6 +448,69 @@ class FlaxCLIPVisionPreTrainedModel: ...@@ -448,6 +448,69 @@ class FlaxCLIPVisionPreTrainedModel:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxDistilBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDistilBertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxElectraForMaskedLM: class FlaxElectraForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
......
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import DistilBertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax.numpy as jnp
from transformers.models.distilbert.modeling_flax_distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
)
class FlaxDistilBertModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = DistilBertConfig(
vocab_size=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,
hidden_dim=self.intermediate_size,
hidden_act=self.hidden_act,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
tie_weights_=True,
)
return config, input_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxDistilBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxDistilBertModel,
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxDistilBertModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("distilbert-base-uncased")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@require_flax
class FlaxDistilBertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = FlaxDistilBertModel.from_pretrained("distilbert-base-uncased")
input_ids = np.array([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
output = model(input_ids, attention_mask=attention_mask)[0]
expected_shape = (1, 11, 768)
self.assertEqual(output.shape, expected_shape)
expected_slice = np.array([[[-0.1639, 0.3299, 0.1648], [-0.1746, 0.3289, 0.1710], [-0.1884, 0.3357, 0.1810]]])
self.assertTrue(jnp.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment