chenpangpang / transformers / Commits / ed9b8481

Commit ed9b8481 (unverified), authored Dec 21, 2019 by Thomas Wolf, committed via GitHub on Dec 21, 2019
Parents: 7e17f09f, f86ed231

Merge pull request #1840 from huggingface/generation_sampler

[WIP] Sampling sequence generator for transformers
Showing 8 changed files with 639 additions and 208 deletions (+639, -208):

examples/run_generation.py                  +139  -177
transformers/configuration_utils.py          +16    -0
transformers/configuration_xlm.py             +4    -0
transformers/modeling_encoder_decoder.py     +37   -29
transformers/modeling_transfo_xl.py           +9    -1
transformers/modeling_utils.py              +398    -1
transformers/modeling_xlm.py                 +12    -0
transformers/modeling_xlnet.py               +24    -0
examples/run_generation.py

The example script is rewritten around the new PreTrainedModel.generate() API.

@@ -20,14 +20,10 @@

`from tqdm import trange`, `import torch.nn.functional as F` and the configuration imports (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig) are removed, since the sampling loop now lives in the library. The remaining imports are unchanged:

import argparse
import logging

import torch
import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
...

@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer

ALL_MODELS (built from the removed *Config classes) is dropped, logging.basicConfig is reformatted, and string literals switch from single to double quotes:

from transformers import XLMWithLMHeadModel, XLMTokenizer


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
...

@@ -75,81 +71,79 @@ def set_seed(args):

    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

The local decoding helpers are deleted. The old top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')) moves essentially unchanged into transformers/modeling_utils.py below (where it gains a min_tokens_to_keep argument and a top_p=1.0 "off" default instead of 0.0), and sample_sequence() is replaced by PreTrainedModel.generate(); its XLNet/XLM-specific input tricks move into the models' new prepare_inputs_for_generation() hooks. Removed:

def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        for _ in trange(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            if is_xlm_mlm and xlm_mask_token:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[i, _] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated

Added in their place: model-specific prompt preparation and a length sanity check.

#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need masked token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def main():
...

@@ -157,108 +151,76 @@ def main():

The interactive `while True:` prompt loop, the inline CTRL temperature warning, the XLM language / MLM mask-token handling and the call to the removed sample_sequence() are all replaced by a single call to model.generate(). --num_samples is dropped, --top_k/--top_p become --k/--p, --xlm_lang becomes --xlm_language, --padding_text gains a help text, and the remaining arguments only change quoting, ordering or help wording. The body of main() now reads:

    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument("--temperature", type=float, default=1.0,
                        help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
    parser.add_argument("--repetition_penalty", type=float, default=1.0,
                        help="primarily useful for CTRL model; in that case, use 1.2")
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        prompt_text = prepare_input(args, model, tokenizer, prompt_text)
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=args.length,
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
    )

    # Batch size == 1. to add more examples please use num_return_sequences > 1
    generated_sequence = output_sequences[0].tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    text = text[: text.find(args.stop_token) if args.stop_token else None]

    print(text)

    return text


if __name__ == "__main__":
    main()
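For orientation, here is a minimal sketch of the flow the rewritten script now follows, written against this commit's API and assuming the public "gpt2" checkpoint can be downloaded. Note that the committed script never passes do_sample, so it decodes greedily unless the configuration default is changed.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to("cpu")

# Encode the prompt the same way the script does: no special tokens, PyTorch tensors.
encoded_prompt = tokenizer.encode("The sampling sequence generator", add_special_tokens=False, return_tensors="pt")

# Sampling must be requested explicitly (do_sample defaults to False on the config).
output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=40,
    do_sample=True,
    temperature=0.9,
    top_k=20,
    top_p=0.9,
    repetition_penalty=1.0,
)

text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
print(text)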
transformers/configuration_utils.py

@@ -56,8 +56,24 @@ class PretrainedConfig(object):

Default generation hyper-parameters are added to PretrainedConfig.__init__, next to the existing attributes, so every model configuration now carries fallback values for generate():

        self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})

        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
        self.is_decoder = kwargs.pop('is_decoder', False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop('max_length', 20)
        self.do_sample = kwargs.pop('do_sample', False)
        self.num_beams = kwargs.pop('num_beams', 1)
        self.temperature = kwargs.pop('temperature', 1.0)
        self.top_k = kwargs.pop('top_k', 50)
        self.top_p = kwargs.pop('top_p', 1.0)
        self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
        self.bos_token_id = kwargs.pop('bos_token_id', 0)
        self.pad_token_id = kwargs.pop('pad_token_id', 0)
        self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
        self.length_penalty = kwargs.pop('length_penalty', 1.)
        self.num_return_sequences = kwargs.pop('num_return_sequences', 1)

        # Fine-tuning task arguments
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
...
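Because these are ordinary config attributes, they can be set per model and generate() picks them up whenever the matching argument is left as None. A small sketch (GPT2Config is used here only because, like the other configs, it forwards extra keyword arguments to PretrainedConfig):

from transformers import GPT2Config

config = GPT2Config(do_sample=True, top_k=20, top_p=0.95, max_length=40)
print(config.do_sample, config.top_k, config.top_p, config.max_length)
# A model built from this config, e.g. GPT2LMHeadModel(config), would now sample with
# these settings on model.generate(input_ids) unless they are overridden in the call.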
transformers/configuration_xlm.py

@@ -110,6 +110,8 @@ class XLMConfig(PretrainedConfig):

Two new constructor arguments, mask_token_id and lang_id, are added so that XLM's prepare_inputs_for_generation() (see transformers/modeling_xlm.py below) can build its inputs from the configuration alone:

                 summary_first_dropout=0.1,
                 start_n_top=5,
                 end_n_top=5,
                 mask_token_id=0,
                 lang_id=0,
                 **kwargs):
        """Constructs XLMConfig.
        """

@@ -143,6 +145,8 @@ class XLMConfig(PretrainedConfig):

        self.summary_first_dropout = summary_first_dropout
        self.start_n_top = start_n_top
        self.end_n_top = end_n_top
        self.mask_token_id = mask_token_id
        self.lang_id = lang_id
        if "n_words" in kwargs:
            self.n_words = kwargs["n_words"]
...
transformers/modeling_encoder_decoder.py

@@ -18,9 +18,11 @@

Two imports are added:

import logging
import os
import warnings

import torch
from torch import nn
from tqdm import trange

from .modeling_auto import AutoModel, AutoModelWithLMHead
...

@@ -119,8 +121,7 @@ class PreTrainedEncoderDecoder(nn.Module):

Context around the keyword-argument splitting in from_pretrained() (only whitespace changes here):

        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_")
            and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
...

@@ -220,49 +221,56 @@ class PreTrainedEncoderDecoder(nn.Module):

forward() no longer splits **kwargs inline; that logic moves into a new prepare_model_kwargs() static method, which also forwards the encoder attention mask to the decoder:

            Indices of decoder input sequence tokens in the vocabulary.
        kwargs: (`optional`) Remaining dictionary of keyword arguments.
        """
        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)

        return decoder_outputs + encoder_outputs

    @staticmethod
    def prepare_model_kwargs(**kwargs):
        """ Prepare the encoder and decoder's keyword arguments.

        Keyword arguments come in 3 flavors:
        - encoder-specific (prefixed by `encoder_`)
        - decoder-specific (prefixed by `decoder_`)
        - those that apply to the model as whole.

        We let the specific kwargs override the common ones in case of
        conflict.
        """
        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_")
            and not argument.startswith("decoder_")
        }
        decoder_kwargs = kwargs_common.copy()
        encoder_kwargs = kwargs_common.copy()
        encoder_kwargs.update(
            {
                argument[len("encoder_"):]: value
                for argument, value in kwargs.items()
                if argument.startswith("encoder_")
            }
        )
        decoder_kwargs.update(
            {
                argument[len("decoder_"):]: value
                for argument, value in kwargs.items()
                if argument.startswith("decoder_")
            }
        )
        decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
        return encoder_kwargs, decoder_kwargs

The previous inline version of this block (the "keyword arguments come in 3 flavors" comment, the kwargs_encoder/kwargs_decoder splitting, the kwargs_decoder["encoder_attention_mask"] assignment and the call self.decoder(decoder_input_ids, **kwargs_decoder)) is removed from forward().

class Model2Model(PreTrainedEncoderDecoder):
...
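To make the prefix convention concrete, here is a dependency-free toy version of the same splitting rule (split_kwargs is a made-up name for this note, not part of the library):

def split_kwargs(**kwargs):
    # Unprefixed arguments are shared; "encoder_"/"decoder_" arguments override them
    # for the matching sub-model, and the encoder attention mask is forwarded so the
    # decoder can build its cross-attention mask.
    common = {k: v for k, v in kwargs.items()
              if not k.startswith("encoder_") and not k.startswith("decoder_")}
    encoder = dict(common, **{k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")})
    decoder = dict(common, **{k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")})
    decoder["encoder_attention_mask"] = encoder.get("attention_mask", None)
    return encoder, decoder

enc, dec = split_kwargs(attention_mask="shared_mask", decoder_attention_mask="decoder_mask")
print(enc["attention_mask"])          # "shared_mask"
print(dec["attention_mask"])          # "decoder_mask" (the decoder-specific value wins)
print(dec["encoder_attention_mask"])  # "shared_mask" (forwarded for cross-attention)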
transformers/modeling_transfo_xl.py

@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter

LogUniformSampler is added to the utilities import:

from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)
...

@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):

A get_output_embeddings() accessor is added so the new generate() method can locate the LM head, which depends on whether adaptive softmax is used:

        outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)

    def get_output_embeddings(self):
        """ Double-check if you are using adaptive softmax.
        """
        if self.sample_softmax > 0:
            return self.out_layer
        else:
            return self.crit.out_layers[-1]
transformers/modeling_utils.py

The module copyright line now also credits Facebook AI Research authors (the beam-search code below is adapted from XLM):

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

@@ -496,6 +496,403 @@ class PreTrainedModel(nn.Module):

        return model

The sequence generator itself: a default prepare_inputs_for_generation() hook plus the new generate() entry point are added to PreTrainedModel:

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    @torch.no_grad()
    def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 bos_token_id=None, pad_token_id=None, eos_token_ids=None, length_penalty=None,
                 num_return_sequences=None):
        """ Sequence generator for models with a LM head.

        The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
        and beam-search.

        Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM

        Params:
            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,)
            **max_length**: (`optional`) int
                The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling. Default to greedy sampling.
            **num_beams**: (`optional`) int
                Number of beams for beam search. 1 means no beam serach. Default to 1.
            **temperature**: (`optional`) float
                The value used to module the next token probabilities.
            **top_k**: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
            **top_p**: (`optional`) float
                The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1.
            **bos_token_id**: (`optional`) int
                Beginning of sentence token if no prompt is provided. Default to 0.
            **eos_token_ids**: (`optional`) int or list of int
                End of sequence token or list of tokens to stop the generation. Default to 0.
            **length_penalty**: (`optional`) int
                Exponential penalty to the length. Default to 0.
            **length_penalty**: (`optional`) float
                Exponential penalty to the length. Default to 1.
            **num_return_sequences**: (`optional`) int
                The number of independantly computed returned sequences for each element in the batch. Default to 1.
        """
        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried to generate sequences with a model that does not have a LM Head."
                                 "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")

        max_length = max_length if max_length is not None else self.config.max_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overriden by the input batch_size
        else:
            batch_size = 1
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
        # assert temperature >= 0, "`temperature` should be positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
        assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \
            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert length_penalty > 0, "`length_penalty` should be strictely positive."
        assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer."

        if input_ids is None:
            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # current position and vocab size
        cur_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size

        if num_return_sequences != 1:
            # Expand input to num return sequences
            input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
            input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size * num_return_sequences, cur_len)
            effective_batch_size = batch_size * num_return_sequences
        else:
            effective_batch_size = batch_size

        if num_beams > 1:
            output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
                                                temperature, top_k, top_p, repetition_penalty,
                                                pad_token_id, eos_token_ids, effective_batch_size,
                                                length_penalty, num_beams, vocab_size)
        else:
            output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                                   temperature, top_k, top_p, repetition_penalty,
                                                   pad_token_id, eos_token_ids, effective_batch_size)

        if num_return_sequences != 1:
            output = output.view(batch_size, num_return_sequences, -1)
        return output
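As a usage sketch (illustration only; it assumes the public "gpt2" checkpoint is available), the arguments above simply route the call to one of the two private helpers that follow:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = tokenizer.encode("Beam search or sampling", add_special_tokens=False, return_tensors="pt")

greedy = model.generate(input_ids=prompt, max_length=20)                             # num_beams=1, do_sample=False -> _generate_no_beam_search, greedy
sampled = model.generate(input_ids=prompt, max_length=20, do_sample=True, top_k=20)  # -> _generate_no_beam_search with top-k sampling
beamed = model.generate(input_ids=prompt, max_length=20, num_beams=3)                # -> _generate_beam_search
print(tokenizer.decode(beamed[0].tolist()))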
The two decoding back-ends:

    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature,
                                 top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids,
                                 batch_size):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequence are generated independantly.
        """
        # current position / max lengths / length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)

        # TODO: add cached compute states
        pasts = None

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for previous_tokens in set(input_ids[i].tolist()):
                        next_token_logits[i, previous_tokens] /= repetition_penalty

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature > 0 and temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            for eos_token_id in eos_token_ids:
                unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if unfinished_sents.max() == 0:
                break

        # add eos_token_ids to unfinished sentences
        if cur_len == max_length:
            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])

        return input_ids

    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature,
                              top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids,
                              batch_size, length_penalty, num_beams, vocab_size):
        """ Generate sequences for each example with beam search.
        """
        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)

        # generated hypotheses
        generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        pasts = None  # self.prepare_pasts()

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                for i in range(batch_size * num_beams):
                    for previous_tokens in set(input_ids[i].tolist()):
                        scores[i, previous_tokens] /= repetition_penalty

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature > 0 and temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
                # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
                # Compute next scores
                _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
                _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
                next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
                # Match shape of greedy beam search
                next_words = next_words.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
                next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
            else:
                # do greedy beam search
                scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
                assert scores.size() == (batch_size * num_beams, vocab_size)
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together (we are keeping top hypothesis accross beams)
                _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
                next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for batch_ex in range(batch_size):

                # if we are done with this sentence
                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size

                    # end of sentence, or next word
                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
                        generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
                    else:
                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # update next beam content
                assert len(next_sent_beam) == 0 if cur_len + 1 == max_length else num_beams
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_ex + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)

            # TODO: Activate cache
            # for k in cache.keys():
            #     if k != 'slen':
            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(batch_size):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = input_ids.new(batch_size)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
        for i, hypo in enumerate(best):
            decoded[i, :tgt_len[i] - 1] = hypo
            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

        return decoded
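The flattened-index bookkeeping in the greedy-beam branch can be sanity-checked on a toy tensor (numbers invented, no model involved):

import torch

# Once per-sentence scores are flattened to (num_beams * vocab_size), each top-k index
# decomposes into the beam it extends and the vocabulary token it would append.
vocab_size, num_beams = 10, 3
flat_scores = torch.randn(num_beams * vocab_size)
top_scores, top_idx = torch.topk(flat_scores, 2 * num_beams)
for idx in top_idx:
    beam_id = idx // vocab_size  # which beam the candidate extends
    word_id = idx % vocab_size   # which vocabulary token it would append
    print(int(beam_id), int(word_id))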
Two module-level helpers round out the file: the (slightly generalised) filtering function that previously lived in examples/run_generation.py, and a small container for beam hypotheses.

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


class BeamHypotheses(object):

    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
...
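A quick numerical check of the new filtering helper above (illustration only; the import path matches where the function is defined in this commit):

import torch
import torch.nn.functional as F

from transformers.modeling_utils import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

top2 = top_k_top_p_filtering(logits.clone(), top_k=2)
print(top2)  # only the two largest logits survive, the rest become -inf

nucleus = F.softmax(top_k_top_p_filtering(logits.clone(), top_p=0.8), dim=-1)
print(nucleus)  # probability mass is renormalised over the tokens kept by the 0.8 nucleus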
transformers/modeling_xlm.py

@@ -649,6 +649,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):

The new prepare_inputs_for_generation() hook appends a mask token at the position to be predicted and builds the language-id tensor from the new configuration fields:

    def get_output_embeddings(self):
        return self.pred_layer.proj

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        mask_token_id = self.config.mask_token_id
        lang_id = self.config.lang_id

        mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, mask_token], dim=1)
        if lang_id is not None:
            langs = torch.full_like(input_ids, lang_id)
        else:
            langs = None
        return {"input_ids": input_ids, "langs": langs}

    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
...
transformers/modeling_xlnet.py

@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):

XLNet predicts a masked slot rather than a shifted next token, so its new prepare_inputs_for_generation() hook appends a dummy token and builds the permutation mask and target mapping that expose only that last position as a prediction target:

    def get_output_embeddings(self):
        return self.lm_loss

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        # Add dummy token at the end (no attention on this one)
        dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # Build permutation mask so that previous tokens don't see last token
        perm_mask = torch.zeros(
            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
            dtype=torch.float, device=input_ids.device)
        perm_mask[:, :, -1] = 1.0

        # We'll only predict the last token
        target_mapping = torch.zeros(
            (input_ids.shape[0], 1, input_ids.shape[1]),
            dtype=torch.float, device=input_ids.device)
        target_mapping[0, 0, -1] = 1.0

        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}

    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
...
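The tensors this hook builds can be shape-checked with plain torch (dummy values, no XLNet weights needed):

import torch

input_ids = torch.zeros((1, 4), dtype=torch.long)        # a dummy 4-token prompt
dummy_token = torch.zeros((1, 1), dtype=torch.long)
input_ids = torch.cat([input_ids, dummy_token], dim=1)   # (1, 5): prompt plus the slot to predict

perm_mask = torch.zeros((input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]))
perm_mask[:, :, -1] = 1.0                                 # no position may attend to the slot being predicted

target_mapping = torch.zeros((input_ids.shape[0], 1, input_ids.shape[1]))
target_mapping[0, 0, -1] = 1.0                            # only the last position is a prediction target

print(input_ids.shape, perm_mask.shape, target_mapping.shape)  # (1, 5), (1, 5, 5), (1, 1, 5)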