chenpangpang / transformers

Commit 7daacf00 (unverified)
Authored Nov 05, 2019 by Julien Chaumond, committed by GitHub on Nov 05, 2019

    Merge pull request #1695 from huggingface/models_inputs_embeds

    model forwards can take an inputs_embeds param

Parents: a44f112f, 00337e96
Changes (23)

Showing 3 changed files with 98 additions and 41 deletions (+98 −41)

    transformers/modeling_xlm.py                  +40 −21
    transformers/modeling_xlnet.py                +45 −20
    transformers/tests/modeling_common_test.py    +13 −0
transformers/modeling_xlm.py

@@ -311,6 +311,10 @@ XLM_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
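
The docstring added above describes the new `inputs_embeds` argument; the hunks that follow wire it through `XLMModel.forward` and the task heads. A minimal usage sketch of the new calling convention, not taken from the commit itself (the `xlm-mlm-en-2048` checkpoint name and the sample sentence are assumptions):

import torch
from transformers import XLMModel, XLMTokenizer

# Assumed checkpoint; any XLM checkpoint behaves the same way here.
tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = XLMModel.from_pretrained("xlm-mlm-en-2048")
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

# Existing path: token ids are looked up in the model's own embedding matrix.
hidden_from_ids = model(input_ids)[0]

# New path added by this PR: build the embeddings yourself and bypass the lookup.
inputs_embeds = model.get_input_embeddings()(input_ids)   # (batch, seq_len, emb_dim)
hidden_from_embeds = model(inputs_embeds=inputs_embeds)[0]

# For an unpadded input the two paths should agree.
print(torch.allclose(hidden_from_ids, hidden_from_embeds, atol=1e-5))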

@@ -421,14 +425,21 @@ class XLMModel(XLMPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.attentions[layer].prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None):  # removed: src_enc=None, src_len=None
+        if input_ids is not None:
+            bs, slen = input_ids.size()
+        else:
+            bs, slen = inputs_embeds.size()[:-1]
+
         if lengths is None:
-            lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            if input_ids is not None:
+                lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            else:
+                lengths = torch.LongTensor([slen] * bs)
         # mask = input_ids != self.pad_index
 
         # check inputs
-        bs, slen = input_ids.size()
         assert lengths.size(0) == bs
         assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
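
A side effect of the `lengths` fallback in this hunk: when only `inputs_embeds` is supplied, the model cannot detect padding via `pad_index`, so every sequence is treated as full length unless `lengths` (or an attention mask) is passed explicitly. A small sketch with made-up shapes:

import torch

# Illustrative shapes only; in the model, bs and slen come from inputs_embeds.size()[:-1].
bs, slen, emb_dim = 2, 5, 16
inputs_embeds = torch.randn(bs, slen, emb_dim)

# What XLMModel.forward now falls back to when input_ids is None and lengths is None:
lengths = torch.LongTensor([slen] * bs)
print(lengths)  # tensor([5, 5]) -- no padding assumed, so pass lengths yourself for padded batches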

@@ -442,10 +453,12 @@ class XLMModel(XLMPreTrainedModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
 
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         # position_ids
         if position_ids is None:
-            position_ids = input_ids.new((slen,)).long()
-            position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
+            position_ids = torch.arange(slen, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
         else:
             assert position_ids.size() == (bs, slen)  # (slen, bs)
             # position_ids = position_ids.transpose(0, 1)

@@ -471,7 +484,7 @@ class XLMModel(XLMPreTrainedModel):
             head_mask = [None] * self.n_layers
 
         # do not recompute cached elements
-        if cache is not None:
+        if cache is not None and input_ids is not None:
             _slen = slen - cache['slen']
             input_ids = input_ids[:, -_slen:]
             position_ids = position_ids[:, -_slen:]

@@ -481,8 +494,10 @@ class XLMModel(XLMPreTrainedModel):
             attn_mask = attn_mask[:, -_slen:]
 
         # embeddings
-        tensor = self.embeddings(input_ids)
-        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
         if langs is not None and self.use_lang_emb:
             tensor = tensor + self.lang_embeddings(langs)
         if token_type_ids is not None:

@@ -624,8 +639,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                langs=langs,

@@ -633,7 +648,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]
         outputs = self.pred_layer(output, labels)

@@ -685,8 +701,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                langs=langs,

@@ -694,7 +710,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)

@@ -768,8 +785,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                langs=langs,

@@ -777,7 +794,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         sequence_output = transformer_outputs[0]

@@ -863,8 +881,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
                 is_impossible=None, cls_index=None, p_mask=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,

@@ -873,7 +891,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]
transformers/modeling_xlnet.py

@@ -558,6 +558,10 @@ XLNET_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
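
The same `inputs_embeds` docstring is added for XLNet, and the next hunk reworks `XLNetModel.forward` accordingly. A hedged usage sketch (the `xlnet-base-cased` checkpoint and the sample text are assumptions, not part of the diff):

import torch
from transformers import XLNetModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")  # assumed checkpoint
model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

# Either pass token ids as before...
outputs = model(input_ids)

# ...or, after this change, pass the embedded representation directly.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model(inputs_embeds=inputs_embeds)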

@@ -712,19 +716,29 @@ class XLNetModel(XLNetPreTrainedModel):
         pos_emb = pos_emb.to(next(self.parameters()))
         return pos_emb
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension
         # so we move here the first dimension (batch) to the end
-        input_ids = input_ids.transpose(0, 1).contiguous()
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = input_ids.transpose(0, 1).contiguous()
+            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
+        elif inputs_embeds is not None:
+            inputs_embeds.transpose(0, 1).contiguous()
+            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
         input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
         attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
         perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
         target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
 
-        qlen, bsz = input_ids.shape[0], input_ids.shape[1]
         mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen
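
The validation added at the top of `XLNetModel.forward` makes `input_ids` and `inputs_embeds` mutually exclusive and requires at least one of them. A small sketch of the behaviour it encodes, using a tiny randomly initialised model (the config values here are assumptions chosen only to keep the example fast):

import torch
from transformers import XLNetConfig, XLNetModel

# Tiny config for illustration only; real checkpoints are much larger.
config = XLNetConfig(vocab_size=100, d_model=32, n_layer=2, n_head=2, d_inner=64)
model = XLNetModel(config)
model.eval()

input_ids = torch.randint(0, 100, (2, 7))
inputs_embeds = model.get_input_embeddings()(input_ids)

try:
    model(input_ids, inputs_embeds=inputs_embeds)  # both given -> rejected
except ValueError as err:
    print(err)  # You cannot specify both input_ids and inputs_embeds at the same time

try:
    model()                                        # neither given -> rejected
except ValueError as err:
    print(err)  # You have to specify either input_ids or inputs_embeds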

@@ -777,7 +791,10 @@ class XLNetModel(XLNetPreTrainedModel):
             non_tgt_mask = None
 
         ##### Word embeddings and prepare h & g hidden states
-        word_emb_k = self.word_embedding(input_ids)
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
+        else:
+            word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k)
         if target_mapping is not None:
             word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)

@@ -924,8 +941,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                mems=mems,

@@ -933,7 +950,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         logits = self.lm_loss(transformer_outputs[0])

@@ -998,8 +1016,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                mems=mems,

@@ -1007,7 +1025,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]
         output = self.sequence_summary(output)

@@ -1049,6 +1068,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension

@@ -1093,9 +1116,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
-                labels=None, head_mask=None):
+                labels=None, head_mask=None, inputs_embeds=None):
         num_choices = input_ids.shape[1]
 
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))

@@ -1106,7 +1129,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
                                                input_mask=flat_input_mask, attention_mask=flat_attention_mask,
                                                mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask, inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]

@@ -1178,8 +1201,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
         outputs = self.transformer(input_ids,

@@ -1189,7 +1212,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
                                    target_mapping=target_mapping,
                                    token_type_ids=token_type_ids,
                                    input_mask=input_mask,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]

@@ -1294,8 +1318,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None,
                 is_impossible=None, cls_index=None, p_mask=None,):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,

@@ -1304,7 +1328,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         start_logits = self.start_logits(hidden_states, p_mask=p_mask)
transformers/tests/modeling_common_test.py

@@ -525,6 +525,19 @@ class CommonTestCases:
             # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
             # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+        def test_inputs_embeds(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            input_ids = inputs_dict["input_ids"]
+            del inputs_dict["input_ids"]
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model.eval()
+
+                wte = model.get_input_embeddings()
+                inputs_dict["inputs_embeds"] = wte(input_ids)
+                outputs = model(**inputs_dict)
+
 
     class GPTModelTester(CommonModelTester):
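
The new common test swaps `input_ids` for the output of the model's own embedding layer and runs the forward pass for every registered model class. A hand-rolled version of the same check for a single model, outside the test harness (the tiny config values are assumptions chosen only to keep the model small):

import torch
from transformers import XLMConfig, XLMModel

# Tiny randomly initialised XLM; the config numbers are illustrative assumptions.
config = XLMConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=2)
model = XLMModel(config)
model.eval()

input_ids = torch.randint(0, 100, (2, 7))

# Mirror the common test: feed embeddings instead of token ids.
wte = model.get_input_embeddings()
outputs = model(inputs_embeds=wte(input_ids))
print(outputs[0].shape)  # expected: (2, 7, 32) -- batch, sequence, emb_dim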