Commit 5bc3d0cc authored Jul 15, 2019 by thomwolf

added gpt2 doc

parent 183fedfe
Showing 2 changed files with 214 additions and 215 deletions

pytorch_transformers/modeling_bert.py   +30 -25
pytorch_transformers/modeling_gpt2.py   +184 -190
pytorch_transformers/modeling_bert.py
@@ -277,8 +277,9 @@ class BertEmbeddings(nn.Module):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

-    def forward(self, input_ids, token_type_ids=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
         if token_type_ids is None:
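For reference, a small standalone sketch (plain PyTorch, not part of the diff) of what the new default branch computes when position_ids is left as None: a (batch_size, sequence_length) tensor of increasing position indices. The input_ids values below are made-up token ids.

    import torch

    # Hypothetical tokenized batch: batch_size=2, sequence_length=5.
    input_ids = torch.tensor([[101, 2023, 2003, 1037, 102],
                              [101, 7592, 2088,  999, 102]])

    seq_length = input_ids.size(1)
    # Same construction as the new default branch in BertEmbeddings.forward above.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

    print(position_ids)
    # tensor([[0, 1, 2, 3, 4],
    #         [0, 1, 2, 3, 4]])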
@@ -624,6 +625,9 @@ BERT_INPUTS_DOCSTRING = r"""
             Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
             See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
             :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of positions of each input sequence token in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
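With position_ids documented as an optional input, a caller can now override the default 0..seq_length-1 numbering, for example to keep absolute positions fixed for left-padded batches. A minimal usage sketch, assuming the bert-base-uncased pretrained weights are available (names follow the library's own docstring examples):

    import torch
    from pytorch_transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1

    # Explicit position indices; values must lie in [0, config.max_position_embeddings - 1].
    position_ids = torch.arange(input_ids.size(1), dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, position_ids=position_ids)
    sequence_output = outputs[0]  # (batch_size, sequence_length, hidden_size)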
@@ -687,7 +691,7 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -723,7 +727,7 @@ class BertModel(BertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers

-        embedding_output = self.embeddings(input_ids, token_type_ids)
+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
         encoder_outputs = self.encoder(embedding_output,
                                        extended_attention_mask,
                                        head_mask=head_mask)
@@ -773,7 +777,7 @@ class BertForPreTraining(BertPreTrainedModel):
         >>> model = BertForPreTraining(config)
         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         >>> outputs = model(input_ids)
-        >>> prediction_scores, seq_relationship_scores = outputs[:1]
+        >>> prediction_scores, seq_relationship_scores = outputs[:2]

     """
     def __init__(self, config):
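The docstring fix above (and the matching ones in the other heads below) is plain tuple arithmetic: these models return a tuple, so unpacking two names needs the first two elements. A tiny illustration with a stand-in tuple:

    # Stand-in for a model output tuple: (prediction_scores, seq_relationship_scores, ...).
    outputs = ('prediction_scores', 'seq_relationship_scores', 'hidden_states')

    a, b = outputs[:2]    # works: a 2-tuple unpacks into two names
    # a, b = outputs[:1]  # would raise ValueError: not enough values to unpack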
@@ -792,9 +796,9 @@ class BertForPreTraining(BertPreTrainedModel):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                 next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -842,7 +846,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         >>> model = BertForMaskedLM(config)
         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         >>> outputs = model(input_ids, masked_lm_labels=input_ids)
-        >>> loss, prediction_scores = outputs[:1]
+        >>> loss, prediction_scores = outputs[:2]

     """
     def __init__(self, config):
@@ -861,8 +865,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
@@ -918,8 +922,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         pooled_output = outputs[1]
         seq_relationship_score = self.cls(pooled_output)
@@ -966,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
         >>> outputs = model(input_ids, labels=labels)
-        >>> loss, logits = outputs[:1]
+        >>> loss, logits = outputs[:2]

     """
     def __init__(self, config):
@@ -979,8 +983,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
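Because position_ids is inserted before token_type_ids in each signature, existing code that passed these heads purely positional arguments (for example model(input_ids, token_type_ids, attention_mask)) would now feed token_type_ids into the position_ids slot. A cautious sketch of the keyword-argument call for BertForSequenceClassification, which is unaffected by the new parameter order (bert-base-uncased is assumed to be available):

    import torch
    from pytorch_transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)

    # Keyword arguments keep working regardless of where position_ids sits in the signature.
    with torch.no_grad():
        outputs = model(input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask)
    logits = outputs[0]  # (batch_size, num_labels)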
@@ -1071,7 +1075,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
         >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
         >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
         >>> outputs = model(input_ids, labels=labels)
-        >>> loss, classification_scores = outputs[:1]
+        >>> loss, classification_scores = outputs[:2]

     """
     def __init__(self, config):
@@ -1083,13 +1087,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
         num_choices = input_ids.shape[1]

         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
+        outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)

         pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
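The flattening above folds every (batch_size, num_choices, sequence_length) input into (batch_size * num_choices, sequence_length) before the shared encoder runs, and flat_position_ids now receives the same treatment. A shape-only sketch with dummy tensors:

    import torch

    batch_size, num_choices, seq_len = 2, 4, 7
    input_ids = torch.randint(0, 30522, (batch_size, num_choices, seq_len))
    position_ids = torch.arange(seq_len).repeat(batch_size, num_choices, 1)

    # Same reshape as in BertForMultipleChoice.forward above.
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))
    flat_position_ids = position_ids.view(-1, position_ids.size(-1))

    print(flat_input_ids.shape)     # torch.Size([8, 7])
    print(flat_position_ids.shape)  # torch.Size([8, 7])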
@@ -1137,7 +1142,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         >>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
         >>> outputs = model(input_ids, labels=labels)
-        >>> loss, scores = outputs[:1]
+        >>> loss, scores = outputs[:2]

     """
     def __init__(self, config):
@@ -1150,8 +1155,8 @@ class BertForTokenClassification(BertPreTrainedModel):
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         sequence_output = outputs[0]
         sequence_output = self.dropout(sequence_output)
@@ -1177,7 +1182,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                       the hidden-states output to compute `span start logits` and `span end logits`). """,
     BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForQuestionAnswering(BertPreTrainedModel):
-    r"""
+    __doc__ = r"""
         **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Position (index) of the start of the labelled span for computing the token classification loss.
             Positions are clamped to the length of the sequence (`sequence_length`).
@@ -1224,9 +1229,9 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
                 end_positions=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

         sequence_output = outputs[0]
         logits = self.qa_outputs(sequence_output)
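Finally, a hedged end-to-end sketch for the updated BertForQuestionAnswering.forward. The checkpoint name and the span labels are illustrative only, and the assumption that the loss comes first when labels are supplied mirrors the loss-first convention in the docstring examples above:

    import torch
    from pytorch_transformers import BertForQuestionAnswering, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    start_positions = torch.tensor([1])  # arbitrary labelled span start
    end_positions = torch.tensor([3])    # arbitrary labelled span end

    outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
    loss, start_logits, end_logits = outputs[:3]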
pytorch_transformers/modeling_gpt2.py

This diff is collapsed.