Unverified Commit a5b68232 authored by Funtowicz Morgan, committed by GitHub

Flax/Jax documentation (#8331)



* First addition of Flax/Jax documentation
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* make style

* Ensure input order matches between Bert & Roberta
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Install dependencies "all" when building doc
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* wraps build_doc deps with ""
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing @sgugger comments.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use list to highlight JAX features.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Make style.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Let's not look too much into the future for now.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Style
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent c7b6bbec
@@ -281,7 +281,7 @@ jobs:
           - v0.4-build_doc-{{ checksum "setup.py" }}
           - v0.4-{{ checksum "setup.py" }}
     - run: pip install --upgrade pip
-    - run: pip install .[tf,torch,sentencepiece,docs]
+    - run: pip install ."[all, docs]"
     - save_cache:
         key: v0.4-build_doc-{{ checksum "setup.py" }}
         paths:
@@ -188,3 +188,10 @@ TFBertForQuestionAnswering
 .. autoclass:: transformers.TFBertForQuestionAnswering
     :members: call
+
+
+FlaxBertModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBertModel
+    :members: __call__
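For orientation (not part of the diff): the newly documented FlaxBertModel is used much like its PyTorch counterpart. A minimal sketch, assuming a Flax-enabled install and that the bert-base-cased checkpoint can be loaded into the Flax model (checkpoint name and example text are illustrative):

    # Illustrative usage only; not part of this commit.
    from transformers import BertTokenizer, FlaxBertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    model = FlaxBertModel.from_pretrained("bert-base-cased")

    # The tokenizer returns numpy arrays, which the Flax model consumes directly.
    inputs = tokenizer("Hello, Flax!", return_tensors="np")
    outputs = model(**inputs)  # dispatches to the FlaxBertModel.__call__ documented above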
@@ -146,3 +146,10 @@ TFRobertaForQuestionAnswering
 .. autoclass:: transformers.TFRobertaForQuestionAnswering
     :members: call
+
+
+FlaxRobertaModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxRobertaModel
+    :members: __call__
@@ -22,7 +22,7 @@ import jax
 import jax.numpy as jnp

 from .configuration_bert import BertConfig
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .modeling_flax_utils import FlaxPreTrainedModel, gelu
 from .utils import logging
@@ -35,13 +35,20 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
 BERT_START_DOCSTRING = r"""
-    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-    pruning heads etc.)
-    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-    general usage and behavior.
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its model (such as downloading, saving and converting weights from
+    PyTorch models)
+    This model is also a Flax Linen `flax.nn.Module
+    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

     Parameters:
         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
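To make the JIT bullet in the new docstring concrete, here is a hedged sketch (the wrapper function, checkpoint name, and example text are illustrative and not part of this commit):

    import jax
    from transformers import BertTokenizer, FlaxBertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    model = FlaxBertModel.from_pretrained("bert-base-cased")
    inputs = tokenizer("JAX compiles this forward pass once per input shape.", return_tensors="np")

    # jax.jit traces the forward pass and compiles it with XLA; later calls
    # with the same input shapes reuse the compiled program.
    @jax.jit
    def forward(input_ids, attention_mask):
        return model(input_ids, attention_mask=attention_mask)

    outputs = forward(inputs["input_ids"], inputs["attention_mask"])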
@@ -52,50 +59,32 @@ BERT_START_DOCSTRING = r"""
 BERT_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+        input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
             Indices of input sequence tokens in the vocabulary.
             Indices can be obtained using :class:`~transformers.BertTokenizer`. See
-            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            :meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
             details.
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+        attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+        token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
             1]``:
             - 0 corresponds to a `sentence A` token,
             - 1 corresponds to a `sentence B` token.
-            `What are token type IDs? <../glossary.html#token-type-ids>`_
+            `What are token type IDs? <../glossary.html#token-type-ids>`__
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+        position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
             config.max_position_embeddings - 1]``.
-            `What are position IDs? <../glossary.html#position-ids>`_
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
-        output_attentions (:obj:`bool`, `optional`):
-            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
-            tensors for more detail.
-        output_hidden_states (:obj:`bool`, `optional`):
-            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
-            more detail.
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """
@@ -291,7 +280,7 @@ class FlaxBertModule(nn.Module):
     intermediate_size: int

     @nn.compact
-    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
         # Embedding
         embeddings = FlaxBertEmbeddings(
@@ -410,7 +399,8 @@ class FlaxBertModel(FlaxPreTrainedModel):
     def module(self) -> nn.Module:
         return self._module

-    def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
         if token_type_ids is None:
             token_type_ids = jnp.ones_like(input_ids)
@@ -423,7 +413,7 @@ class FlaxBertModel(FlaxPreTrainedModel):
         return self.model.apply(
             {"params": self.params},
             jnp.array(input_ids, dtype="i4"),
+            jnp.array(attention_mask, dtype="i4"),
             jnp.array(token_type_ids, dtype="i4"),
             jnp.array(position_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
         )
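The self.model.apply({"params": self.params}, ...) call above follows the usual Flax Linen pattern of keeping parameters outside the module instance; a minimal standalone sketch of that pattern (generic, not BERT-specific):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyModule(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            # Parameters are created by init() and passed back in at every apply().
            return nn.Dense(self.features)(x)

    module = TinyModule(features=4)
    x = jnp.ones((2, 3))
    variables = module.init(jax.random.PRNGKey(0), x)  # a {"params": ...} pytree
    y = module.apply(variables, x)                     # same calling convention as self.model.apply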
@@ -21,7 +21,7 @@ import jax
 import jax.numpy as jnp

 from .configuration_roberta import RobertaConfig
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .modeling_flax_utils import FlaxPreTrainedModel, gelu
 from .utils import logging
@@ -34,13 +34,20 @@ _TOKENIZER_FOR_DOC = "RobertaTokenizer"
 ROBERTA_START_DOCSTRING = r"""
-    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-    pruning heads etc.)
-    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-    general usage and behavior.
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its model (such as downloading, saving and converting weights from
+    PyTorch models)
+    This model is also a Flax Linen `flax.nn.Module
+    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

     Parameters:
         config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
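The Vectorization bullet can likewise be sketched with jax.vmap, which lifts a per-example function over a leading batch axis (purely illustrative; the pooling function below is not part of the library):

    import jax
    import jax.numpy as jnp

    # Per-example function: mean-pool one (sequence_length, hidden_size) array.
    def mean_pool(hidden_states):
        return hidden_states.mean(axis=0)

    # jax.vmap maps it over the batch dimension without a Python loop.
    batched_mean_pool = jax.vmap(mean_pool)

    hidden_states = jnp.ones((8, 128, 768))    # e.g. (batch, seq, hidden) from a RoBERTa forward pass
    pooled = batched_mean_pool(hidden_states)  # shape (8, 768)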
@@ -51,50 +58,32 @@ ROBERTA_START_DOCSTRING = r"""
 ROBERTA_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+        input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
             Indices of input sequence tokens in the vocabulary.
-            Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
-            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+            :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
             details.
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+        attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+        token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
             1]``:
             - 0 corresponds to a `sentence A` token,
             - 1 corresponds to a `sentence B` token.
-            `What are token type IDs? <../glossary.html#token-type-ids>`_
+            `What are token type IDs? <../glossary.html#token-type-ids>`__
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+        position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
             config.max_position_embeddings - 1]``.
-            `What are position IDs? <../glossary.html#position-ids>`_
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
-        output_attentions (:obj:`bool`, `optional`):
-            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
-            tensors for more detail.
-        output_hidden_states (:obj:`bool`, `optional`):
-            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
-            more detail.
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """
@@ -302,7 +291,7 @@ class FlaxRobertaModule(nn.Module):
     intermediate_size: int

     @nn.compact
-    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
         # Embedding
         embeddings = FlaxRobertaEmbeddings(
@@ -421,7 +410,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
     def module(self) -> nn.Module:
         return self._module

-    def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
         if token_type_ids is None:
             token_type_ids = jnp.ones_like(input_ids)
@@ -436,7 +426,7 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
         return self.model.apply(
             {"params": self.params},
             jnp.array(input_ids, dtype="i4"),
+            jnp.array(attention_mask, dtype="i4"),
             jnp.array(token_type_ids, dtype="i4"),
             jnp.array(position_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
         )
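Since this commit reorders positional arguments in the Flax __call__ methods and in the underlying apply calls, passing inputs by keyword is the most robust calling style. A hedged sketch (checkpoint name and example text are illustrative):

    from transformers import RobertaTokenizer, FlaxRobertaModel

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = FlaxRobertaModel.from_pretrained("roberta-base")

    encoded = tokenizer("Keyword arguments survive positional reordering.", return_tensors="np")

    # Keyword arguments keep the call correct even though BERT and RoBERTa
    # list their optional arguments in a different order.
    outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])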