chenpangpang / transformers · Commits

Commit 07bc8efb, authored Nov 15, 2019 by Rémi Louf

    add greedy decoding and sampling

Parent: e57d00ee

Showing 7 changed files with 785 additions and 202 deletions (+785 −202).
Changed files:

  examples/run_generation.py                +140  −170
  transformers/modeling_encoder_decoder.py  +132   −30
  transformers/modeling_transfo_xl.py         +9    −1
  transformers/modeling_utils.py            +229    −0
  transformers/modeling_xlm.py               +28    −1
  transformers/modeling_xlnet.py             +34    −0
  transformers/tests/sampling_test.py       +213    −0
examples/run_generation.py (view file @ 07bc8efb)
@@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals

The script no longer imports `tqdm.trange`, `torch.nn.functional` or the configuration classes (the `ALL_MODELS` list built from `pretrained_config_archive_map` is dropped); string quoting is also normalized from single to double quotes throughout the file. The imports become:

import argparse
import logging

import torch
import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer

@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer

from transformers import XLMWithLMHeadModel, XLMTokenizer


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -75,81 +71,78 @@ def set_seed(args):

    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

The local `top_k_top_p_filtering` and `sample_sequence` helpers (top-k / nucleus filtering of the logits, the XLNet dummy-token / perm-mask / target-mapping handling, the XLM mask-token and language handling, the CTRL repetition penalty and the greedy-vs-multinomial choice when `temperature == 0`) are removed from the example script; that logic now lives in the `Sampler` class and the `decode()` methods added to the library (see transformers/modeling_utils.py below). In their place the script gains small per-model input-preparation functions, a dispatch table, and a length helper:

#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text, {}


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    kwargs = {"language": None, "mask_token": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
        kwargs["language"] = tokenizer.lang2id[language]

    # XLM masked-language modeling (MLM) models need masked token
    is_xlm_mlm = "mlm" in args.model_name_or_path
    if is_xlm_mlm:
        kwargs["mask_token"] = tokenizer.mask_token_id

    return prompt_text, kwargs


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
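A minimal sketch (not part of this commit) of how the preprocessing dispatch above is meant to be used. It assumes the definitions above (`PREPROCESSING_FUNCTIONS`, `PADDING_TEXT`) are in scope, e.g. inside examples/run_generation.py, and the Namespace fields mirror the argparse options defined in main() below.

```python
# Illustrative only: exercise the "xlnet" preprocessing function added above.
# Assumes PREPROCESSING_FUNCTIONS and PADDING_TEXT from run_generation.py are in scope.
import argparse

args = argparse.Namespace(
    padding_text="", temperature=0.7, model_name_or_path="xlnet-base-cased", xlm_language=""
)

prepare_input = PREPROCESSING_FUNCTIONS.get("xlnet")
prompt_text, model_kwargs = prepare_input(args, None, None, "My prompt")

# prompt_text now starts with the long PADDING_TEXT paragraph; model_kwargs is {} for XLNet,
# while the "xlm" function would instead return {"language": ..., "mask_token": ...}.
print(prompt_text[:40], model_kwargs)
```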
@@ -157,104 +150,81 @@ def main():

The argument list is tightened: --num_samples, --top_k and --top_p are dropped in favor of --k and --p (consumed by decode()), --xlm_lang becomes --xlm_language, --padding_text gains a help string, and the --model_name_or_path help now lists MODEL_CLASSES.keys() instead of the removed ALL_MODELS. The hand-rolled generation (the interactive while-True prompt loop, the in-script XLM language selection, the CTRL control-code warning, the inline length clamping and the call to the deleted sample_sequence) is replaced by the preprocessing functions above and a single call to model.decode(). The rewritten main() reads:

    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling")
    parser.add_argument("--repetition_penalty", type=float, default=1.0,
                        help="primarily useful for CTRL model; in that case, use 1.2")
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)
    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError as ke:
        raise ke("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)
    model.eval()

    args.length = adjust_length_to_model(
        args.length, max_sequence_length=model.config.max_position_embeddings
    )
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    model_kwargs = {}
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
    encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)

    output_sequences = model.decode(
        prompt_ids=encoded_prompt,
        length=args.length,
        temperature=args.temperature,
        k=args.k,
        p=args.p,
        repetition_penalty=args.repetition_penalty,
        device=args.device,
        **model_kwargs,
    )

    generated_sequence = output_sequences.tolist()[encoded_prompt.size(1):]  # adapted to case where num_samples > 1
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    text = text[: text.find(args.stop_token) if args.stop_token else None]

    print(text)

    return text


if __name__ == "__main__":
    main()
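For reference, a hedged sketch of driving the new decode() API the way the rewritten script does; the checkpoint name, prompt and sampling settings below are illustrative, and note that run_generation.py itself leaves do_sample at its greedy default.

```python
# Sketch only: mirror run_generation.py's new flow with a GPT-2 checkpoint.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt_text = "The Transformers library"
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)

# GPT-2 has no preprocessing function in PREPROCESSING_FUNCTIONS, so no extra kwargs are needed.
output_sequence = model.decode(
    prompt_ids=encoded_prompt,
    length=20,
    do_sample=True,      # sampling switched on explicitly for this sketch
    temperature=0.9,
    k=20,
    p=0.9,
    repetition_penalty=1.0,
)

generated_ids = output_sequence.tolist()[encoded_prompt.size(1):]
print(tokenizer.decode(generated_ids, clean_up_tokenization_spaces=True))
```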
transformers/modeling_encoder_decoder.py (view file @ 07bc8efb)

@@ -18,11 +18,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals

Three imports are added (`warnings`, `tqdm.trange` and the new `Sampler`); the rest is unchanged:

import logging
import os
import warnings

import torch
from torch import nn
from tqdm import trange

from .modeling_auto import AutoModel, AutoModelWithLMHead
from .modeling_utils import Sampler

logger = logging.getLogger(__name__)
@@ -117,8 +120,7 @@ class PreTrainedEncoderDecoder(nn.Module):

Only line reflow changes in this hunk; the kwargs-splitting block it shows reads:

        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
@@ -186,51 +188,151 @@ class PreTrainedEncoderDecoder(nn.Module):

                Indices of decoder input sequence tokens in the vocabulary.
            kwargs: (`optional`) Remaining dictionary of keyword arguments.
        """

The body of forward() now delegates keyword-argument splitting to the new prepare_model_kwargs() static method (the old inline comments about the three kwargs flavors, the encoder_attention_mask assignment and the "# Decode" section move there), and a decode() entry point plus the _greedy_decode_or_sample() loop are added:

        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)

        return decoder_outputs + encoder_outputs

    def decode(
        self,
        encoder_input_ids,
        decoder_prompt_ids=None,
        device=torch.device("cpu"),
        length=10,
        do_sample=False,
        temperature=1.0,
        k=9,
        p=0.,
        repetition_penalty=1.,
        **kwargs
    ):
        """ Generic sequence generator for encoder-decoder models.

        For encoder-decoders the generation consists in:
        - Performing a forward pass through the encoder once;
        - Pass the encoder's hidden states to a decoding mechanism that
          repeatedly calls the decoder to generate sequences.

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling.

        Params:
            **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
                The sequence to encode.
            **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,)
            **device**: (`optional`) `torch.device`
                The device on which the prompt_ids will be initialized if not provided.
            **length**: (`optional`) int
                The length of the sequence to be generated.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
            **temperature**: (`optional`) float
                The value used to module the next token probabilities.
            **k**: (`optional`) int
                The parameter used for k-filtering.
            **p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
        """
        if decoder_prompt_ids is None:
            decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

        # When the model does not have a LM head `get_output_embeddings`
        # returns `None`. We use this mechanism to determine whether we
        # should proceed with decoding or not.
        if self.decoder.get_output_embeddings() is None:
            raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.")

        # The followings checks that the decoder is on the same device as the one
        # that is specified. It only works for models that fit on one GPU.
        decoder_device = next(self.decoder.parameters()).device
        if decoder_device != decoder_prompt_ids.device:
            warnings.warn(
                "The decoder is not on the same device as the prompt. Expected {}, got {}.".format(
                    decoder_prompt_ids.device, decoder_device
                )
            )

        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
        with torch.no_grad():
            encoder_outputs = self.encoder(encoder_input_ids, **kwargs)
            encoder_hidden_states = encoder_outputs[0]
        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states

        sampler_config = {
            "k": k,
            "p": p,
            "do_sample": do_sample,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        return self._greedy_decode_or_sample(decoder_prompt_ids, length, sampler_config, **kwargs_decoder)

    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder):
        sampler = Sampler(**sampler_config)
        with torch.no_grad():
            generated_sequence = prompt_ids
            for _ in trange(length):
                arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder)
                outputs = self.decoder(**arguments)
                next_tokens_logits = outputs[0][:, -1, :]
                next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

        return generated_sequence.squeeze(0)

    @staticmethod
    def prepare_model_kwargs(**kwargs):
        """ Prepare the encoder and decoder's keyword arguments.

        Keyword arguments come in 3 flavors:
        - encoder-specific (prefixed by `encoder_`)
        - decoder-specific (prefixed by `decoder_`)
        - those that apply to the model as whole.

        We let the specific kwargs override the common ones in case of
        conflict.
        """
        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
        }
        decoder_kwargs = kwargs_common.copy()
        encoder_kwargs = kwargs_common.copy()
        encoder_kwargs.update(
            {
                argument[len("encoder_"):]: value
                for argument, value in kwargs.items()
                if argument.startswith("encoder_")
            }
        )
        decoder_kwargs.update(
            {
                argument[len("decoder_"):]: value
                for argument, value in kwargs.items()
                if argument.startswith("decoder_")
            }
        )
        decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
        return encoder_kwargs, decoder_kwargs


class Model2Model(PreTrainedEncoderDecoder):
    ...
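A usage sketch for the encoder-decoder path added above, mirroring the slow test in transformers/tests/sampling_test.py; the checkpoint name and the toy ids are taken from that test and are otherwise assumptions.

```python
# Sketch only: the encoder runs once, then _greedy_decode_or_sample calls the decoder step by step.
import torch
from transformers import Model2Model

model = Model2Model.from_pretrained("bert-base-uncased")
encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
decoder_prompt_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)

generated = model.decode(
    encoder_input_ids,
    decoder_prompt_ids=decoder_prompt_ids,
    length=5,
    k=2,
    p=0.5,
    repetition_penalty=2,
)
print(generated.shape)  # 1-D tensor: 3 prompt ids followed by 5 generated ids
```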
transformers/modeling_transfo_xl.py (view file @ 07bc8efb)

@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter

The utilities import gains LogUniformSampler:

from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)

@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):

        outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)

    def get_output_embeddings(self):
        """ Double-check if you are using adaptive softmax.
        """
        if self.sample_softmax > 0:
            return self.out_layer
        else:
            return self.crit.out_layers[-1]
transformers/modeling_utils.py (view file @ 07bc8efb)

@@ -23,12 +23,14 @@ import json

`warnings` and `tqdm.trange` are added to the imports:

import logging
import os
from io import open
import warnings

import six
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from tqdm import trange

from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -87,6 +89,93 @@ class PreTrainedModel(nn.Module):

    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def decode(
        self,
        prompt_ids=None,
        device=torch.device('cpu'),
        length=10,
        do_sample=False,
        temperature=1.,
        k=9,
        p=0,
        repetition_penalty=1,
        **model_kwargs
    ):
        """ Generic sequence generator for single-stack models with a LM head.

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling.

        Params:
            **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
                The sequence to encode.
            **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,)
            **device**: (`optional`) `torch.device`
                The device on which the prompt_ids will be initialized if not provided.
            **length**: (`optional`) int
                The length of the sequence to be generated.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
            **temperature**: (`optional`) float
                The value used to module the next token probabilities.
            **k**: (`optional`) int
                The parameter used for k-filtering.
            **p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
        """
        if prompt_ids is None:
            prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

        # When the model does not have a LM head `get_output_embeddings`
        # returns `None`. We use this mechanism to determine whether we
        # should proceed with decoding or not.
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")

        # The followings checks that the model is on the same device as the one
        # that is specified. It only works for models that fit on one GPU.
        model_device = next(self.parameters()).device
        if model_device != prompt_ids.device:
            warnings.warn(
                "The model is not on the same device as the prompts. Expected {}, got {}.".format(
                    prompt_ids.device, model_device
                )
            )

        sampler_config = {
            "k": k,
            "p": p,
            "do_sample": do_sample,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)

    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
        """ Generate text using greedy decoding or by sampling tokens."""
        sampler = Sampler(**sampler_config)
        generated_sequence = prompt_ids
        with torch.no_grad():
            for _ in trange(length):
                arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
                outputs = self(**arguments)
                next_tokens_logits = outputs[0][:, -1, :]
                next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

        return generated_sequence.squeeze(0)

    def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
        arguments = {"input_ids": input_ids}
        arguments.update(kwargs)
        return arguments

    def get_input_embeddings(self):
        """ Get model's input embeddings
        """
@@ -859,3 +948,143 @@ def prune_layer(layer, index, dim=None):

        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))


class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in chosing the most probable token at each
    step, is the default behaviour. Sampling with varying temperature, top_k
    and nucleus filtering is also implemented.

    Attributes:
        **device**: ``torch.device``
            Device on which the computations will be run.
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering
        **p**: float between 0 and 1
            Parameter for the nucleus filtering
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low temperatures
            put more emphasis on highly probably token while high temperatures tend
            to smooth the probability distribution.
        **repetition_penalty**: strictly postitive float
            The penalty applied to repeating ids
    """

    def __init__(self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False

        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
                However p is a probability and its value must lie between 0 and 1. In effect, no filtering
                will be applied. If this is not the behavior you expect, change the value of p.""".format(self.p)
            )

    def get_one_token(self, next_token_logits, past_sequence):
        logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if self.do_sample:
            logits = self.apply_temperature(logits)
            logits = self.apply_top_k_filter(logits)
            logits = self.apply_nucleus_filter(logits)
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        return torch.argmax(logits, dim=-1).unsqueeze(-1)

    def apply_repetition_penalty(self, logits, past_sequence):
        """ Apply a penalty to tokens that appear more than once in the
        generated sequence.

        .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
        language model for controllable generation." arXiv preprint
        arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            generated_token_idx = set(past_sequence[0].tolist())
            for token_idx in generated_token_idx:
                logits[0, token_idx] /= self.repetition_penalty
        return logits

    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The higher the value
        of the temperature, the more skewed towards high probability events the
        distribution is.

        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
        MIT press, 2016.
        """
        # when dividing a float by 0, torch returns inf which in turns breaks the
        # multinomial with an error message that is not very helpful. It is better
        # for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
                If you wanted to do greedy sampling, set instead `do_sample` to False.
                Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature

    def apply_top_k_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically we select the set of size k such that
        the sum of its items' probabilities is maximum.

        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
        story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k > 0:
            vocabulary_size = logits.size(-1)
            if self.k > vocabulary_size:
                warnings.warn(
                    """You provided a value for k ({}) that is larger than the vocabulary size ({}).
                    We adjusted k's value to the vocabulary size; if that was what you intended to do
                    we recommend setting k to 0 instead. It this is not the behavior you expected,
                    choose a value of k that is smaller than the vocabulary size.""".format(self.k, vocabulary_size)
                )
                self.k = vocabulary_size

            indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
            logits[indices_to_remove] = -float("Inf")

        return logits

    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].

        .. Holtzman, Ari, et al. "The curious case of neural text
        degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p > 0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probabilities = F.softmax(sorted_logits, dim=-1)
            cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

            # Remove tokens with cumulative probability above the threshold,
            # but keep the first token above the threshold.
            sorted_indices_to_remove = cumulative_probabilities > self.p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = -float("Inf")

        return logits
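A toy walk-through of the Sampler filters; the numbers mirror the unit tests added in transformers/tests/sampling_test.py, the rest is an illustrative sketch, not part of the commit.

```python
# Sketch only: apply the Sampler's filters to a tiny batch of logits.
import torch
from transformers.modeling_utils import Sampler

sampler = Sampler(do_sample=True, temperature=1.0, k=2, p=0.71, repetition_penalty=1.0)
logits = torch.tensor([[0.7, 0.1, 0.2]])

# k=2 masks everything below the second-highest logit.
print(sampler.apply_top_k_filter(logits.clone()))    # tensor([[0.7, -inf, 0.2]])
# p=0.71 keeps the smallest set of tokens whose cumulative probability exceeds 0.71.
print(sampler.apply_nucleus_filter(logits.clone()))  # tensor([[0.7, -inf, 0.2]])

# get_one_token chains repetition penalty, temperature and both filters, then samples one id.
next_token = sampler.get_one_token(logits.clone(), past_sequence=torch.tensor([[]]))
print(next_token.shape)  # torch.Size([1, 1])
```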
transformers/modeling_xlm.py (view file @ 07bc8efb)

@@ -646,7 +646,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):

Only formatting changes in the forward call shown here (langs=langs, token_type_ids=token_type_ids, position_ids=position_ids, lengths=lengths, cache=cache, head_mask=head_mask, inputs_embeds=inputs_embeds).

@@ -657,6 +657,33 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):

        return outputs

    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        mask_token = model_kwargs.pop("mask_token", None)
        language = model_kwargs.pop("language", None)
        input_ids = self._append_mask_token(input_ids, mask_token)
        langs = self._create_language_embeddings(input_ids, language)
        arguments = {"input_ids": input_ids, "langs": langs}
        arguments.update(model_kwargs)
        return arguments

    @staticmethod
    def _append_mask_token(sequence, mask_token_id):
        """ Append a [MASK] token at the end of the sequence that the MLM model
        is going to try to predict.
        """
        if mask_token_id is not None:
            tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
            return torch.cat((sequence, tokens_to_append), dim=1)
        return sequence

    @staticmethod
    def _create_language_embeddings(sequence, language):
        if language is not None:
            return torch.tensor([language] * sequence.shape[1]).view(1, -1)
        return None


@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    ...
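The two static helpers above can be exercised directly; a small sketch with an assumed mask-token id of 5 and language id 0 (values chosen for illustration only).

```python
# Sketch only: what XLM's decoding helpers build for a 3-token prompt.
import torch
from transformers import XLMWithLMHeadModel

input_ids = torch.tensor([[10, 11, 12]], dtype=torch.long)
with_mask = XLMWithLMHeadModel._append_mask_token(input_ids, 5)        # tensor([[10, 11, 12,  5]])
langs = XLMWithLMHeadModel._create_language_embeddings(with_mask, 0)   # tensor([[0, 0, 0, 0]])
print(with_mask, langs)
```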
transformers/modeling_xlnet.py (view file @ 07bc8efb)

@@ -972,6 +972,40 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):

        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        input_ids = self._add_dummy_token(input_ids)
        perm_mask = self._create_perm_mask(input_ids)
        target_mapping = self._create_target_mapping(input_ids)
        arguments = {
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
        }
        return arguments

    @staticmethod
    def _add_dummy_token(sequence):
        dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
        return torch.cat((sequence, dummy), dim=1)

    @staticmethod
    def _create_perm_mask(sequence):
        mask = torch.zeros(
            (sequence.shape[0], sequence.shape[1], sequence.shape[1]),
            dtype=torch.float,
        )
        mask[:, :, -1] = 1.0  # Previous tokens don't see last token
        return mask

    @staticmethod
    def _create_target_mapping(sequence):
        target_mapping = torch.zeros(
            (sequence.shape[0], 1, sequence.shape[1]),
            dtype=torch.float,
        )
        target_mapping[0, 0, -1] = 1.0  # predict last token
        return target_mapping


@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    ...
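The XLNet helpers are static as well, so their outputs can be checked without instantiating a model; a small sketch for a length-3 prompt (toy ids are illustrative).

```python
# Sketch only: shapes produced by XLNet's new decoding helpers.
import torch
from transformers import XLNetLMHeadModel

prompt = torch.tensor([[4, 5, 6]], dtype=torch.long)
input_ids = XLNetLMHeadModel._add_dummy_token(prompt)                 # shape (1, 4): dummy token appended
perm_mask = XLNetLMHeadModel._create_perm_mask(input_ids)             # shape (1, 4, 4): last position hidden from all tokens
target_mapping = XLNetLMHeadModel._create_target_mapping(input_ids)   # shape (1, 1, 4): predict only the appended slot
print(input_ids.shape, perm_mask.shape, target_mapping.shape)
```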
transformers/tests/sampling_test.py (new file, mode 0 → 100644; view file @ 07bc8efb)

# coding=utf-8
import sys
import unittest

import numpy as np
import pytest

from transformers import is_torch_available

if is_torch_available():
    import torch

    from transformers import (
        BertConfig,
        BertModel,
        GPT2Config,
        GPT2LMHeadModel,
        OpenAIGPTConfig,
        OpenAIGPTLMHeadModel,
        TransfoXLConfig,
        TransfoXLLMHeadModel,
        XLMConfig,
        XLMWithLMHeadModel,
        XLNetConfig,
        XLNetLMHeadModel,
        Model2Model,
    )
    from transformers.modeling_utils import Sampler
else:
    pytestmark = pytest.mark.skip("Require Torch")


class SamplerTest(unittest.TestCase):
    def test_nucleus_sampling(self):
        inf = -float("Inf")
        test_cases = (
            {"p": 0, "logits": torch.tensor([0.3, 0.1, 0.2]), "expected": torch.tensor([0.3, 0.1, 0.2])},
            {"p": 0.01, "logits": torch.tensor([0.3, 0.1, 0.2]), "expected": torch.tensor([0.3, inf, inf])},
            {"p": 1, "logits": torch.tensor([0.3, 0.1, 0.2]), "expected": torch.tensor([0.3, 0.1, 0.2])},
            {"p": 0.2, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, inf, inf])},
            {"p": 0.71, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, inf, 0.2])},
            {"p": 0.71, "logits": torch.tensor([0.1, 0.7, 0.2]), "expected": torch.tensor([inf, 0.7, 0.2])},
            {"p": 0.71, "logits": torch.tensor([0.7, 0.2, 0.1]), "expected": torch.tensor([0.7, 0.2, inf])},
            {"p": 0.91, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, 0.1, 0.2])},
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": 0,
                "p": case["p"],
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_nucleus_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())

    def test_top_k_filter(self):
        inf = -float("Inf")
        test_cases = (
            {"k": 0, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, 0.1, 0.2])},
            {"k": 1, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, inf, inf])},
            {"k": 2, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, inf, 0.2])},
            {"k": 3, "logits": torch.tensor([0.7, 0.1, 0.2]), "expected": torch.tensor([0.7, 0.1, 0.2])},
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": case["k"],
                "p": 0,
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_top_k_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())

    @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
    def test_wrong_k_value(self):
        case = {"k": 10, "vocab_size": 5}
        config = {
            "do_sample": True,
            "temperature": 1.0,
            "k": case["k"],
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertWarns(UserWarning):
            _ = sampler.get_one_token(next_token_logits, past_sequence)

    def test_zero_temperature(self):
        temperature = 0
        config = {
            "do_sample": True,
            "temperature": temperature,
            "k": 0,
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(10).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertRaises(ZeroDivisionError):
            _ = sampler.get_one_token(next_token_logits, past_sequence)


class SamplerSingleStackTest(unittest.TestCase):
    def test_raises_exception_when_no_LM_head(self):
        models = [BertModel(BertConfig())]
        for model in models:
            with self.assertRaises(AttributeError):
                model.decode()

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        models = {
            "XLNet": XLNetLMHeadModel(XLNetConfig()),
            "XLM": XLMWithLMHeadModel(XLMConfig()),
            "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()),
            "GPT2": GPT2LMHeadModel(GPT2Config()),
            "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()),
        }
        kwargs = {"XLNet": {}, "XLM": {"mask_token": 0}, "TransfoXL": {}, "GPT2": {}, "GPT": {}}

        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8

        for name, model in models.items():
            kwargs_model = kwargs[name]
            output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model)
            self.assertEqual(len(output), expected_length)


class SamplerEncoderDecoderTest(unittest.TestCase):
    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        model = Model2Model.from_pretrained("bert-base-uncased")
        encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8

        output = model.decode(
            encoder_input_ids,
            decoder_prompt_ids=prompt,
            k=2,
            p=0.5,
            repetition_penalty=2,
            length=generated_length,
        )
        self.assertEqual(len(output), expected_length)


if __name__ == "__main__":
    unittest.main()