Commit 9eddf44b authored by Julien Chaumond

docstring + check

parent 8e11de0e
@@ -313,6 +313,9 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want to input a probability distribution of tokens rather than actual tokens.
"""

@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -371,7 +374,9 @@ class GPT2Model(GPT2PreTrainedModel):
        self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            ...
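For context, a minimal sketch of how the new ``inputs_embeds`` path can be exercised, per the "probability distribution of tokens" use case in the docstring above. Mixing a softmax distribution with the rows of the ``wte`` embedding matrix is an illustrative assumption, not part of this commit:

    import torch
    from transformers import GPT2Model

    model = GPT2Model.from_pretrained("gpt2")
    model.eval()

    # A "soft token" per position: a probability distribution over the
    # vocabulary. Random here, purely for illustration.
    probs = torch.softmax(torch.randn(1, 4, model.config.vocab_size), dim=-1)

    # Mix the rows of the token embedding matrix into one vector per position.
    inputs_embeds = probs @ model.wte.weight  # shape (1, 4, embedding_dim)

    with torch.no_grad():
        hidden_states = model(inputs_embeds=inputs_embeds)[0]

    # Passing both arguments now raises the ValueError added in this commit:
    # model(input_ids=torch.tensor([[50256]]), inputs_embeds=inputs_embeds)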
@@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
    r""" Base class for all TF models.
        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
...
@@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
    r""" Base class for all models.
        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
...
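As a quick illustration of the two base-class methods this docstring refers to, a minimal sketch using the PyTorch GPT-2 model (model name and the specific numbers are arbitrary assumptions):

    from transformers import GPT2Model

    model = GPT2Model.from_pretrained("gpt2")

    # (i) Resize the input embeddings, e.g. after adding tokens to the tokenizer.
    model.resize_token_embeddings(model.config.vocab_size + 2)

    # (ii) Prune self-attention heads: {layer_index: [head indices to remove]}.
    model.prune_heads({0: [0, 2]})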