chenpangpang / transformers · Commits

Commit 62b8eb43
Authored Jul 15, 2019 by thomwolf
Parent: 5bc3d0cc

fix add_start_docstrings on python 2 (removed)
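Why this commit exists, in short: on Python 2 the `__doc__` attribute of a class cannot be reassigned after the class is created, so a decorator that rewrites class docstrings breaks at import time. A minimal sketch of the failure mode (the `Model` class is illustrative, not from the repo):

    # Python 2: __doc__ of a (new-style) class is read-only after creation.
    class Model(object):
        """Original docstring."""

    try:
        Model.__doc__ = "Prepended text. " + Model.__doc__
    except AttributeError as exc:
        # Python 2 raises: attribute '__doc__' of 'type' objects is not writable
        print(exc)

On Python 3 the same assignment succeeds, so the commit keeps the docstring-prepending behavior there and turns the decorator into a no-op on Python 2 (see modeling_utils.py below).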
Changes: 3 changed files with 32 additions and 21 deletions

    pytorch_transformers/modeling_bert.py    +8   -8
    pytorch_transformers/modeling_gpt2.py    +5   -4
    pytorch_transformers/modeling_utils.py   +19  -9
pytorch_transformers/modeling_bert.py
@@ -646,7 +646,7 @@ BERT_INPUTS_DOCSTRING = r"""
 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertModel(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
             Sequence of hidden-states at the last layer of the model.
@@ -742,7 +742,7 @@ class BertModel(BertPreTrainedModel):
     a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForPreTraining(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Labels for computing the masked language modeling loss.
         Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
@@ -818,7 +818,7 @@ class BertForPreTraining(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model transformer BERT model with a `language modeling` head on top. """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Labels for computing the masked language modeling loss.
         Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
@@ -883,7 +883,7 @@ class BertForMaskedLM(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model transformer BERT model with a `next sentence prediction (classification)` head on top. """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForNextSentencePrediction(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
         Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
         Indices should be in ``[0, 1]``.
@@ -941,7 +941,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     the pooled output) e.g. for GLUE tasks. """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForSequenceClassification(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
         Labels for computing the sequence classification/regression loss.
         Indices should be in ``[0, ..., config.num_labels]``.
@@ -1009,7 +1009,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
                       BERT_START_DOCSTRING)
 class BertForMultipleChoice(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     Inputs:
         **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
             Indices of input sequence tokens in the vocabulary.
@@ -1115,7 +1115,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
     the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForTokenClassification(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Labels for computing the token classification loss.
         Indices should be in ``[0, ..., config.num_labels]``.
@@ -1182,7 +1182,7 @@ class BertForTokenClassification(BertPreTrainedModel):
     the hidden-states output to compute `span start logits` and `span end logits`). """,
                       BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForQuestionAnswering(BertPreTrainedModel):
-    __doc__ = r"""
+    r"""
     **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
         Position (index) of the start of the labelled span for computing the token classification loss.
         Positions are clamped to the length of the sequence (`sequence_length`).
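All eight hunks in this file make the same one-line change: the explicit `__doc__ = r"""` assignment at the top of each class body is reverted to a plain raw-string docstring, and prepending the shared `BERT_START_DOCSTRING`/`BERT_INPUTS_DOCSTRING` fragments is left to the decorator. A rough sketch of how the pieces compose on Python 3 (the names below are illustrative, not the repo's):

    def add_start_docstrings(*docstr):
        # Same idea as the Python 3 branch in modeling_utils.py:
        # prepend the shared fragments to the class's own docstring.
        def docstring_decorator(fn):
            fn.__doc__ = ''.join(docstr) + fn.__doc__
            return fn
        return docstring_decorator

    START_DOCSTRING = "Shared intro. "
    INPUTS_DOCSTRING = "Shared inputs section. "

    @add_start_docstrings("One-line model summary. ", START_DOCSTRING, INPUTS_DOCSTRING)
    class ToyModel(object):
        r"""Outputs section specific to this class."""

    print(ToyModel.__doc__)
    # -> One-line model summary. Shared intro. Shared inputs section. Outputs section specific to this class.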
pytorch_transformers/modeling_gpt2.py
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 
 from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
-                             PreTrainedModel, prune_conv1d_layer, SequenceSummary)
+                             PreTrainedModel, prune_conv1d_layer, SequenceSummary,
+                             add_start_docstrings)
 from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
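`add_start_docstrings` is added to this import list because the decorator is applied to the GPT-2 model classes in the hunks below; the function itself lives in modeling_utils.py (the last file in this commit).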
@@ -414,7 +415,7 @@ GPT2_INPUTS_DOCTRING = r""" Inputs:
 @add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
                       GPT2_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
 class GPT2Model(GPT2PreTrainedModel):
-    __doc__ = r"""
+    r"""
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
             Sequence of hidden-states at the last layer of the model.
@@ -539,7 +540,7 @@ class GPT2Model(GPT2PreTrainedModel):
 @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
 (linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    __doc__ = r"""
+    r"""
     **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Labels for language modeling.
         Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
@@ -615,7 +616,7 @@ The language modeling head has its weights tied to the input embeddings,
 the classification head takes as input the input of a specified classification token index in the intput sequence).
 """, GPT2_START_DOCSTRING)
 class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
-    __doc__ = r""" Inputs:
+    r""" Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
         The second dimension of the input (`num_choices`) indicates the number of choices to score.
pytorch_transformers/modeling_utils.py
@@ -15,17 +15,20 @@
 # limitations under the License.
 """PyTorch BERT model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 
-import copy
-import json
 import logging
 import os
+import json
+import copy
 from io import open
 
+import six
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, functional as F
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
 
 from .file_utils import cached_path
@@ -36,11 +39,18 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
 
 
-def add_start_docstrings(*docstr):
-    def docstring_decorator(fn):
-        fn.__doc__ = ''.join(docstr) + fn.__doc__
-        return fn
-    return docstring_decorator
+if not six.PY2:
+    def add_start_docstrings(*docstr):
+        def docstring_decorator(fn):
+            fn.__doc__ = ''.join(docstr) + fn.__doc__
+            return fn
+        return docstring_decorator
+else:
+    # Not possible to update class docstrings on python2
+    def add_start_docstrings(*docstr):
+        def docstring_decorator(fn):
+            return fn
+        return docstring_decorator
 
 
 class PretrainedConfig(object):
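As a quick check of the guarded decorator, a sketch (the import path matches this package at the time of the commit; the `Example` class is illustrative):

    from pytorch_transformers.modeling_utils import add_start_docstrings

    @add_start_docstrings("Prepended summary. ")
    class Example(object):
        """Class-specific docstring."""

    # Python 3: prints "Prepended summary. Class-specific docstring."
    # Python 2: the decorator is a no-op, so only "Class-specific docstring." remains.
    print(Example.__doc__)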