Commit 572dcfd1 authored by LysandreJik

Doc

parent c4ef1034
@@ -47,3 +47,4 @@ The library currently contains PyTorch implementations, pre-trained model weight
model_doc/gpt2
model_doc/xlm
model_doc/xlnet
model_doc/roberta
RoBERTa
----------------------------------------------------
``RobertaConfig``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaConfig
:members:
``RobertaTokenizer``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaTokenizer
:members:
``RobertaModel``
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaModel
:members:
``RobertaForMaskedLM``
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaForMaskedLM
:members:
``RobertaForSequenceClassification``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaForSequenceClassification
:members:
@@ -29,6 +29,8 @@ from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
    BertLayerNorm, BertModel,
    BertPreTrainedModel, gelu)
from pytorch_transformers.modeling_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
@@ -65,11 +67,93 @@ class RobertaEmbeddings(BertEmbeddings):
class RobertaConfig(BertConfig):
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
`RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer,
Veselin Stoyanov. It is based on Google's BERT model released in 2018.
It builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining
objective and training with much larger mini-batches and learning rates.
This implementation is the same as BertModel with a tiny embeddings tweak as well as a setup for Roberta pretrained
models.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
.. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
https://arxiv.org/abs/1907.11692
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
model.
"""
ROBERTA_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
To match pre-training, the RoBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP][SEP] no it is not . [SEP]``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
the ``add_special_tokens`` parameter set to ``True``.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaModel(BertModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input; you're often better off averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
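# Editor's illustrative addition (not in the original docstring): as noted above, averaging the
# hidden states is often a better sentence summary than the pooler output. Assumes no padding
# in this toy batch, so the mask is all ones.
attention_mask = torch.ones_like(input_ids).unsqueeze(-1).float()  # (batch_size, sequence_length, 1)
mean_pooled = (last_hidden_states * attention_mask).sum(1) / attention_mask.sum(1)  # (batch_size, hidden_size)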
""" """
config_class = RobertaConfig config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -82,9 +166,37 @@ class RobertaModel(BertModel):
self.apply(self.init_weights)
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForMaskedLM(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked); the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
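# Editor's illustrative addition (not in the original docstring): computing the loss for a single
# masked position only, using the -1 ignore index described above. The position index 4 is arbitrary.
mask_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
masked_input = input_ids.clone()
masked_input[0, 4] = mask_id  # replace one token with the mask token
labels = input_ids.clone()
labels[masked_input != mask_id] = -1  # every unmasked position is ignored by the loss
loss_on_masked_token = model(masked_input, masked_lm_labels=labels)[0]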
""" """
config_class = RobertaConfig config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -112,14 +224,14 @@ class RobertaForMaskedLM(BertPreTrainedModel):
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
class RobertaLMHead(nn.Module):
@@ -144,9 +256,39 @@ class RobertaLMHead(nn.Module):
return x
@add_start_docstrings("""RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
on top of the pooled output) e.g. for GLUE tasks. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForSequenceClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
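# Editor's illustrative addition (not in the original docstring): turning the raw logits described
# above into class probabilities and a predicted label.
probabilities = torch.nn.functional.softmax(logits, dim=-1)  # (batch_size, config.num_labels)
predicted_class = probabilities.argmax(dim=-1)  # (batch_size,)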
""" """
config_class = RobertaConfig config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
...
@@ -65,8 +65,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
class RobertaTokenizer(PreTrainedTokenizer):
"""
RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: Byte-level BPE
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
...
@@ -180,9 +180,10 @@ class PreTrainedTokenizer(object):
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
r"""
Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
Args:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
@@ -383,14 +384,15 @@ class PreTrainedTokenizer(object):
def add_tokens(self, new_tokens):
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
Args:
new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Examples::
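# Editor's illustrative sketch (the concrete example is truncated in this diff); the token strings
# and the BERT checkpoint are placeholders, and resize_token_embeddings is assumed to be available
# on the model as elsewhere in the library.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new_tok2'])
model.resize_token_embeddings(len(tokenizer))  # make room for the new rows in the embedding matrix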
@@ -422,17 +424,20 @@ class PreTrainedTokenizer(object):
def add_special_tokens(self, special_tokens_dict):
"""
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If special tokens are NOT in the vocabulary, they are added
to it (indexed starting from the last index of the current vocabulary).
Args:
special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
[``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
``additional_special_tokens``].
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Examples::
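# Editor's illustrative sketch (the concrete example is truncated in this diff); the '<CLS>' string
# and the GPT-2 checkpoint are placeholders.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
num_added_toks = tokenizer.add_special_tokens({'cls_token': '<CLS>'})
model.resize_token_embeddings(len(tokenizer))  # make room for the newly added special token
assert tokenizer.cls_token == '<CLS>'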
@@ -520,9 +525,16 @@ class PreTrainedTokenizer(object):
raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False):
"""
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
Args:
text: The first sequence to be encoded.
text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
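Example (editor's illustrative sketch, not in the original docstring; the BERT checkpoint is a placeholder)::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
ids = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)
text = tokenizer.decode(ids, skip_special_tokens=True)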
""" """
if text_pair is None: if text_pair is None:
if add_special_tokens: if add_special_tokens:
@@ -577,9 +589,9 @@ class PreTrainedTokenizer(object):
return ' '.join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
...