Unverified Commit 700229f8 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fixes in the templates (#10951)



* Fixes in the templates

* Define in all cases

* Dimensionality -> Dimension
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent 05c966f2
...@@ -44,16 +44,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): ...@@ -44,16 +44,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or :obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or
:class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`. :class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`.
Vocabulary size of the model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model`.
hidden_size (:obj:`int`, `optional`, defaults to 768): hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer. Dimension of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder. Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12): num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder. Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072): intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
...@@ -75,14 +73,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): ...@@ -75,14 +73,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models). Only Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``. relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass. If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
{% else -%} {% else -%}
vocab_size (:obj:`int`, `optional`, defaults to 50265): vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or :obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or
:class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`. :class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`.
d_model (:obj:`int`, `optional`, defaults to 1024): d_model (:obj:`int`, `optional`, defaults to 1024):
Dimensionality of the layers and the pooler layer. Dimension of the layers and the pooler layer.
encoder_layers (:obj:`int`, `optional`, defaults to 12): encoder_layers (:obj:`int`, `optional`, defaults to 12):
Number of encoder layers. Number of encoder layers.
decoder_layers (:obj:`int`, `optional`, defaults to 12): decoder_layers (:obj:`int`, `optional`, defaults to 12):
...@@ -92,9 +90,9 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): ...@@ -92,9 +90,9 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder. Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. Dimension of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. Dimension of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
......
...@@ -60,6 +60,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c ...@@ -60,6 +60,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config" _CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
...@@ -730,7 +731,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -730,7 +731,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -807,7 +808,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -807,7 +808,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMaskedLMOutput, output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -903,7 +904,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca ...@@ -903,7 +904,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFCausalLMOutput, output_type=TFCausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1031,7 +1032,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -1031,7 +1032,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSequenceClassifierOutput, output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1137,7 +1138,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -1137,7 +1138,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMultipleChoiceModelOutput, output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1280,7 +1281,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1280,7 +1281,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFTokenClassifierOutput, output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1376,7 +1377,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1376,7 +1377,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFQuestionAnsweringModelOutput, output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1504,6 +1505,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c ...@@ -1504,6 +1505,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config" _CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
...@@ -2512,7 +2514,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -2512,7 +2514,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSeq2SeqModelOutput, output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
......
...@@ -54,6 +54,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c ...@@ -54,6 +54,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config" _CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
...@@ -779,7 +780,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna ...@@ -779,7 +780,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions, output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -932,7 +933,7 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m ...@@ -932,7 +933,7 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput, output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1190,7 +1191,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt ...@@ -1190,7 +1191,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput, output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1270,7 +1271,7 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel ...@@ -1270,7 +1271,7 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput, output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1360,7 +1361,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter. ...@@ -1360,7 +1361,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput, output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1447,7 +1448,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ...@@ -1447,7 +1448,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput, output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -1559,6 +1560,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c ...@@ -1559,6 +1560,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config" _CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
...@@ -2607,7 +2609,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna ...@@ -2607,7 +2609,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput, output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -2875,7 +2877,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt ...@@ -2875,7 +2877,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqSequenceClassifierOutput, output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
...@@ -2976,7 +2978,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ...@@ -2976,7 +2978,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="{{cookiecutter.checkpoint_identifier}}", checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqQuestionAnsweringModelOutput, output_type=Seq2SeqQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment