chenpangpang / transformers

Commit 3d2096f5 authored Dec 18, 2019 by thomwolf

    further cleanup

parent 8e5587fb

Showing 6 changed files with 58 additions and 76 deletions:

    examples/run_generation.py          +6   -7
    transformers/configuration_xlm.py   +4   -0
    transformers/modeling_utils.py      +11  -7
    transformers/modeling_xlm.py        +12  -27
    transformers/modeling_xlnet.py      +24  -34
    transformers/tokenization_utils.py  +1   -1
examples/run_generation.py

@@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
 def prepare_xlm_input(args, model, tokenizer, prompt_text):
-    kwargs = {"language": None, "mask_token": None}
+    kwargs = {"language": None, "mask_token_id": None}

     # Set the language
     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb

@@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
     # XLM masked-language modeling (MLM) models need masked token
     is_xlm_mlm = "mlm" in args.model_name_or_path
     if is_xlm_mlm:
-        kwargs["mask_token"] = tokenizer.mask_token_id
+        kwargs["mask_token_id"] = tokenizer.mask_token_id

     return prompt_text, kwargs

@@ -204,14 +204,13 @@ def main():
     prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
     encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)

-    output_sequences = model.decode(
-        prompt_ids=encoded_prompt,
+    output_sequences = model.generate(
+        input_ids=encoded_prompt,
         length=args.length,
         temperature=args.temperature,
-        k=args.k,
-        p=args.p,
+        top_k=args.k,
+        top_p=args.p,
         repetition_penalty=args.repetition_penalty,
-        device=args.device,
         **model_kwargs,
     )
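The call-site change above is the user-visible part of this cleanup: the script now goes through the generic generate() entry point with its argument names (input_ids, top_k, top_p), and anything model-specific travels through **model_kwargs. A minimal usage sketch, not taken from the commit: it assumes the generate() signature shown in the modeling_utils.py hunk below (hence max_length rather than the script's length= keyword) and uses an illustrative GPT-2 checkpoint with made-up sampling values.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

encoded_prompt = torch.tensor(
    tokenizer.encode("The quick brown fox", add_special_tokens=False)
).unsqueeze(0)  # shape (1, prompt_length), same preprocessing as the script

# Same shape of call as the updated run_generation.py: generic argument names
# instead of the old decode()-style ones (prompt_ids, k, p).
output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=40,
    do_sample=True,
    temperature=0.9,
    top_k=20,
    top_p=0.9,
    repetition_penalty=1.0,
)

print(tokenizer.decode(output_sequences[0].tolist()))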
transformers/configuration_xlm.py

@@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig):
                  summary_first_dropout=0.1,
                  start_n_top=5,
                  end_n_top=5,
+                 mask_token_id=0,
+                 lang_id=0,
                  **kwargs):
         """Constructs XLMConfig.
         """

@@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig):
             self.summary_first_dropout = summary_first_dropout
             self.start_n_top = start_n_top
             self.end_n_top = end_n_top
+            self.mask_token_id = mask_token_id
+            self.lang_id = lang_id
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              " or the path to a pretrained model config file (str)")
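The two new fields give XLM models a config-level default that generation code can fall back on. A short sketch of setting them explicitly; the vocabulary size and ids below are illustrative, and per the constructor above both fields default to 0 when omitted.

from transformers import XLMConfig

# First positional argument is the vocabulary size (see the ValueError text above);
# mask_token_id / lang_id are the generation-time defaults added in this commit.
config = XLMConfig(30145, mask_token_id=5, lang_id=0)
print(config.mask_token_id, config.lang_id)  # 5 0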
transformers/modeling_utils.py

@@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module):
     def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
                  temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                  bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
-                 length_penalty=None, num_return_sequences=None, **kwargs):
+                 length_penalty=None, num_return_sequences=None, **model_kwargs):
         """ Sequence generator for models with a LM head.

         The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling

@@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module):
             output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
                                                 temperature, top_k, top_p, repetition_penalty,
                                                 pad_token_id, eos_token_ids, effective_batch_size,
-                                                length_penalty, num_beams, vocab_size)
+                                                length_penalty, num_beams, vocab_size,
+                                                **model_kwargs)
         else:
             output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                                    temperature, top_k, top_p, repetition_penalty,
-                                                   pad_token_id, eos_token_ids, effective_batch_size)
+                                                   pad_token_id, eos_token_ids, effective_batch_size,
+                                                   **model_kwargs)

         if num_return_sequences != 1:
             output = output.view(batch_size, num_return_sequences, -1)

@@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module):
     def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
                                  temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size):
+                                 pad_token_id, eos_token_ids, batch_size,
+                                 **model_kwargs):
         """ Generate sequences for each example without beam search (num_beams == 1).

             All returned sequence are generated independantly.
         """

@@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module):
         pasts = None
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]

@@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module):
     def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
                               temperature, top_k, top_p, repetition_penalty,
                               pad_token_id, eos_token_ids, batch_size,
-                              length_penalty, num_beams, vocab_size):
+                              length_penalty, num_beams, vocab_size,
+                              **model_kwargs):
         """ Generate sequences for each example with beam search.
         """

         # Expand input to num beams

@@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module):
         done = [False for _ in range(batch_size)]

         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
             scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
             scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
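The modeling_utils.py change is pure plumbing: any extra keyword arguments handed to generate() are now forwarded, untouched, to the model's prepare_inputs_for_generation() at every decoding step, in both the greedy and the beam-search loops. A toy sketch of that forwarding pattern, using made-up class and variable names rather than the library's internals.

class ToyGenerator:
    """Illustration only: mirrors how **model_kwargs flow through generate()."""

    def prepare_inputs_for_generation(self, input_ids, pasts=None, **model_kwargs):
        # Model-specific keys (e.g. mask_token_id / lang_id for XLM) arrive here.
        return {"input_ids": input_ids, **model_kwargs}

    def generate(self, input_ids, max_length=5, **model_kwargs):
        cur_len = len(input_ids)
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, pasts=None, **model_kwargs)
            input_ids = input_ids + [0]  # stand-in for sampling the next token
            cur_len += 1
        return input_ids, model_inputs


print(ToyGenerator().generate([1, 2], mask_token_id=5, lang_id=0))
# ([1, 2, 0, 0, 0], {'input_ids': [1, 2, 0, 0], 'mask_token_id': 5, 'lang_id': 0})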
transformers/modeling_xlm.py

@@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj

+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
+        lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
+
+        mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, mask_token], dim=1)
+        if lang_id is not None:
+            langs = torch.full_like(input_ids, lang_id)
+        else:
+            langs = None
+        return {"input_ids": input_ids, "langs": langs}
+
     def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                 lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,

@@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         return outputs

-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        mask_token = model_kwargs.pop("mask_token", None)
-        language = model_kwargs.pop("language", None)
-        input_ids = self._append_mask_token(input_ids, mask_token)
-        langs = self._create_language_embeddings(input_ids, language)
-        arguments = {"input_ids": input_ids, "langs": langs}
-        arguments.update(model_kwargs)
-
-        return arguments
-
-    @staticmethod
-    def _append_mask_token(sequence, mask_token_id):
-        """ Append a [MASK] token at the end of the sequence that the MLM model
-        is going to try to predict.
-        """
-        if mask_token_id is not None:
-            tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
-            return torch.cat((sequence, tokens_to_append), dim=1)
-
-        return sequence
-
-    @staticmethod
-    def _create_language_embeddings(sequence, language):
-        if language is not None:
-            return torch.tensor([language] * sequence.shape[1]).view(1, -1)
-        return None
-

 @add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
     the pooled output) e.g. for GLUE tasks. """,
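In isolation, the new XLM hook does two things at each step: it appends a [MASK] placeholder for the position to be predicted and builds a langs tensor of the same shape. A standalone sketch with illustrative ids (5 standing in for config.mask_token_id, 0 for config.lang_id):

import torch

mask_token_id, lang_id = 5, 0             # stand-ins for the new config defaults
input_ids = torch.tensor([[14, 27, 93]])  # (1, seq_len), made-up token ids

mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)  # append the [MASK] to predict
langs = torch.full_like(input_ids, lang_id) if lang_id is not None else None

print(input_ids)  # tensor([[14, 27, 93,  5]])
print(langs)      # tensor([[0, 0, 0, 0]])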
transformers/modeling_xlnet.py

@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss

+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        # Add dummy token at the end (no attention on this one)
+        dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        # Build permutation mask so that previous tokens don't see last token
+        perm_mask = torch.zeros(
+            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device)
+        perm_mask[:, :, -1] = 1.0
+
+        # We'll only predict the last token
+        target_mapping = torch.zeros(
+            (input_ids.shape[0], 1, input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device)
+        target_mapping[0, 0, -1] = 1.0
+
+        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
     def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                 token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,

@@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        input_ids = self._add_dummy_token(input_ids)
-        perm_mask = self._create_perm_mask(input_ids)
-        target_mapping = self._create_target_mapping(input_ids)
-        arguments = {
-            "input_ids": input_ids,
-            "perm_mask": perm_mask,
-            "target_mapping": target_mapping,
-        }
-        return arguments
-
-    @staticmethod
-    def _add_dummy_token(sequence):
-        dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
-        return torch.cat((sequence, dummy), dim=1)
-
-    @staticmethod
-    def _create_perm_mask(sequence):
-        mask = torch.zeros(
-            (sequence.shape[0], sequence.shape[1], sequence.shape[1]),
-            dtype=torch.float,
-        )
-        mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        return mask
-
-    @staticmethod
-    def _create_target_mapping(sequence):
-        target_mapping = torch.zeros(
-            (sequence.shape[0], 1, sequence.shape[1]),
-            dtype=torch.float,
-        )
-        target_mapping[0, 0, -1] = 1.0  # predict last token
-        return target_mapping
-

 @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
     the pooled output) e.g. for GLUE tasks. """,
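Similarly, a small sketch of the tensors the new XLNet hook builds: a dummy token is appended, the permutation mask hides that last position from every query, and the target mapping selects only that position for prediction. Shapes are the point here; token ids are illustrative.

import torch

input_ids = torch.tensor([[14, 27, 93]])                # (1, seq_len), made-up ids
dummy_token = torch.zeros((1, 1), dtype=torch.long)
input_ids = torch.cat([input_ids, dummy_token], dim=1)  # (1, seq_len + 1)

seq_len = input_ids.shape[1]
perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
perm_mask[:, :, -1] = 1.0          # no token may attend to the appended position

target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
target_mapping[0, 0, -1] = 1.0     # predict only the appended position

print(input_ids.shape, perm_mask.shape, target_mapping.shape)
# torch.Size([1, 4]) torch.Size([1, 4, 4]) torch.Size([1, 1, 4])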
transformers/tokenization_utils.py

@@ -761,7 +761,7 @@ class PreTrainedTokenizer(object):
                 padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                 The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
                 Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
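For reference, a hedged sketch of the behavior this docstring describes, assuming the encode() signature of transformers at roughly this revision; the checkpoint name is illustrative.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

ids = tokenizer.encode("hello world")                          # plain list of python ints
ids_pt = tokenizer.encode("hello world", return_tensors="pt")  # 1 x seq_len torch.Tensor

print(type(ids), ids_pt.shape)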