"examples/flax/vscode:/vscode.git/clone" did not exist on "279ce5b705a0b8689f2a8e5d5258dbb5421c9e6c"
Unverified commit e89c959a, authored by Lysandre Debut, committed by GitHub

Fix model templates (#9999)

parent 804cd185
@@ -1656,7 +1656,7 @@ class BartForCausalLM(BartPretrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.BartTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
@@ -1425,7 +1425,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.BlenderbotTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
@@ -1400,7 +1400,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.BlenderbotSmallTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
@@ -1411,7 +1411,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.MarianTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
@@ -1658,7 +1658,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.MBartTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
@@ -1414,7 +1414,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
             provide it.
-            Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
+            Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See
             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
             for details.
...
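For context, the six hunks above all correct the same copy-paste slip: each model's docstring should point at its own tokenizer rather than `ProphetNetTokenizer`. A minimal sketch of what the corrected text describes, using Bart (the checkpoint name is shown for illustration):

```python
from transformers import BartTokenizer

# Calling the tokenizer (i.e. __call__) returns input_ids plus an
# attention_mask; padding token indices are ignored by the model by default.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
print(inputs["input_ids"])
```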
@@ -55,6 +55,7 @@ if is_torch_available():
         "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
         "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
         "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
+        "{{cookiecutter.camelcase_modelname}}ForCausalLM",
         "{{cookiecutter.camelcase_modelname}}Model",
         "{{cookiecutter.camelcase_modelname}}PreTrainedModel",
     ]
@@ -114,6 +115,7 @@ if TYPE_CHECKING:
     from .modeling_{{cookiecutter.lowercase_modelname}} import (
         {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
         {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
+        {{cookiecutter.camelcase_modelname}}ForCausalLM,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
         {{cookiecutter.camelcase_modelname}}Model,
...
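The two `__init__` hunks above keep the package's lazy-import bookkeeping in sync: a public name must appear both in `_import_structure` and under the mirrored `TYPE_CHECKING` imports. A simplified sketch of the pattern these lists feed into (transformers wires this up through an internal lazy-module helper; the module and class names below are placeholders):

```python
import importlib

# Maps a submodule to the public names it exports; nothing is imported yet.
_import_structure = {
    "modeling_mymodel": ["MyModelForCausalLM", "MyModelModel"],
}

def __getattr__(name):
    # PEP 562 module-level __getattr__: import the owning submodule only
    # when one of its listed names is first accessed.
    for submodule, exported in _import_structure.items():
        if name in exported:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```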
@@ -1546,6 +1546,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
     Seq2SeqQuestionAnsweringModelOutput,
     Seq2SeqSequenceClassifierOutput,
+    CausalLMOutputWithCrossAttentions
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
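The newly imported `CausalLMOutputWithCrossAttentions` is the return type of the templated `ForCausalLM.forward`, bundling loss, logits, cached key/values, and (cross-)attentions. A tiny self-contained illustration (the tensor shape is a placeholder):

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

logits = torch.zeros(1, 7, 50265)  # (batch, sequence_length, vocab_size)
output = CausalLMOutputWithCrossAttentions(logits=logits)
print(output.logits.shape)  # torch.Size([1, 7, 50265])
```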
@@ -1952,7 +1953,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
         return outputs

-# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
+# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
@@ -3066,8 +3067,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
             encoder_attentions=outputs.encoder_attentions,
         )

-# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
-class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel):
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
+class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
     """
     This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
     used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
@@ -3081,8 +3082,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcas
         return self.decoder(*args, **kwargs)

-# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
-class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel):
+# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
+class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         config = copy.deepcopy(config)
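The two classes renamed above follow Bart's decoder-wrapper pattern: the standalone causal LM keeps its decoder under `model.decoder` so its state-dict keys match the full seq2seq checkpoint when it is reused inside :class:`~transformers.EncoderDecoderModel`. A rough, framework-free sketch of the idea (class and attribute names are illustrative, not the template's exact code):

```python
import copy
from types import SimpleNamespace

import torch.nn as nn

class DecoderWrapper(nn.Module):
    """Stores the decoder under the `decoder` attribute so that state-dict
    keys line up with the seq2seq checkpoint's `model.decoder.*` weights."""

    def __init__(self, decoder: nn.Module):
        super().__init__()
        self.decoder = decoder

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)

class ToyForCausalLM(nn.Module):
    def __init__(self, config, decoder: nn.Module):
        super().__init__()
        config = copy.deepcopy(config)
        config.is_decoder = True            # run the decoder standalone
        config.is_encoder_decoder = False
        self.model = DecoderWrapper(decoder)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

# Toy usage: any module with parameters can stand in for the decoder.
cfg = SimpleNamespace(d_model=16, vocab_size=100,
                      is_decoder=False, is_encoder_decoder=True)
lm = ToyForCausalLM(cfg, decoder=nn.Linear(16, 16))
print(list(lm.state_dict())[0])  # "model.decoder.weight"
```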
@@ -3199,8 +3200,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         >>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM

-        >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
-        >>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False)
+        >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
+        >>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
         >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
         >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         >>> outputs = model(**inputs)
...
@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
 from .test_configuration_common import ConfigTester
 from .test_generation_utils import GenerationTesterMixin
-from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
+from .test_modeling_common import ModelTesterMixin, ids_tensor

 if is_torch_available():
@@ -498,6 +498,7 @@ if is_torch_available():
         {{cookiecutter.camelcase_modelname}}Config,
         {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
+        {{cookiecutter.camelcase_modelname}}ForCausalLM,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
         {{cookiecutter.camelcase_modelname}}Model,
         {{cookiecutter.camelcase_modelname}}Tokenizer,
...
@@ -47,6 +47,7 @@
     _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
         [
             "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "{{cookiecutter.camelcase_modelname}}ForCausalLM",
             "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
             "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
             "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
@@ -115,6 +116,7 @@
     from .models.{{cookiecutter.lowercase_modelname}} import (
         {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
         {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
+        {{cookiecutter.camelcase_modelname}}ForCausalLM,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
         {{cookiecutter.camelcase_modelname}}Model,
@@ -209,6 +211,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
 {% else -%}
 from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
     {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
+    {{cookiecutter.camelcase_modelname}}ForCausalLM,
     {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
     {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
     {{cookiecutter.camelcase_modelname}}Model,
@@ -232,10 +235,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
 # Below: "# Model for Causal LM mapping"
 # Replace with:
-{% if cookiecutter.is_encoder_decoder_model == "False" -%}
 ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
-{% else -%}
-{% endif -%}
 # End.

 # Below: "# Model for Masked LM mapping"
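Dropping the Jinja conditional above means the `(Config, ForCausalLM)` pair is now patched into the causal-LM auto-mapping for encoder-decoder templates too, so the auto classes can resolve the new head. For example, with Bart (one of the models this template mirrors), something along these lines works; loading a full seq2seq checkpoint this way keeps only the decoder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# AutoModelForCausalLM looks up BartConfig in the causal-LM mapping and
# instantiates BartForCausalLM (decoder-only) from the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForCausalLM.from_pretrained("facebook/bart-large")
```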
@@ -384,6 +384,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 {% else -%}
     "{{cookiecutter.camelcase_modelname}}Encoder",
     "{{cookiecutter.camelcase_modelname}}Decoder",
+    "{{cookiecutter.camelcase_modelname}}DecoderWrapper",
 {% endif -%}
 # End.
@@ -393,5 +394,6 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 {% else -%}
     "{{cookiecutter.camelcase_modelname}}Encoder",  # Building part of bigger (tested) model.
     "{{cookiecutter.camelcase_modelname}}Decoder",  # Building part of bigger (tested) model.
+    "{{cookiecutter.camelcase_modelname}}DecoderWrapper",  # Building part of bigger (tested) model.
 {% endif -%}
 # End.
@@ -121,6 +121,13 @@ Tips:
     :members: forward

+
+{{cookiecutter.camelcase_modelname}}ForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
+    :members: forward
+
 {% endif -%}
 {% endif -%}

 {% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
@@ -180,5 +187,7 @@ TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
 .. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
     :members: call

 {% endif -%}
 {% endif -%}