chenpangpang / transformers

Commit 300ec300, authored Dec 21, 2019 by thomwolf

fixing run_generation example - using torch.no_grad

Parent: 1c377468
Showing 4 changed files with 32 additions and 38 deletions (+32, -38)
examples/run_generation.py         +14  -17
transformers/configuration_xlm.py   +2   -2
transformers/modeling_utils.py     +13  -16
transformers/modeling_xlm.py        +3   -3
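The commit title refers to running the example's generation step under torch.no_grad. That change itself sits in context that is collapsed in the hunks below, so the following is only a minimal sketch of the pattern, with illustrative names (generate_without_grad is not a function from the repository):

import torch

def generate_without_grad(model, encoded_prompt, **generate_args):
    # Inference does not need autograd; disabling it avoids building a
    # computation graph for every decoding step, saving memory and time.
    with torch.no_grad():
        return model.generate(input_ids=encoded_prompt, **generate_args)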
examples/run_generation.py
@@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
         logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-    return prompt_text, {}
+    return prompt_text
 
 
 def prepare_xlm_input(args, model, tokenizer, prompt_text):
-    kwargs = {"language": None, "mask_token_id": None}
+    # kwargs = {"language": None, "mask_token_id": None}
 
     # Set the language
     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
@@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
             + str(list(available_languages))
             + " >>> "
         )
-        kwargs["language"] = tokenizer.lang2id[language]
+        # kwargs["language"] = tokenizer.lang2id[language]
 
+    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
     # XLM masked-language modeling (MLM) models need masked token
-    is_xlm_mlm = "mlm" in args.model_name_or_path
-    if is_xlm_mlm:
-        kwargs["mask_token_id"] = tokenizer.mask_token_id
+    # is_xlm_mlm = "mlm" in args.model_name_or_path
+    # if is_xlm_mlm:
+    #     kwargs["mask_token_id"] = tokenizer.mask_token_id
 
-    return prompt_text, kwargs
+    return prompt_text
 
 
 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
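With the per-call kwargs commented out above, the XLM-specific values are expected to come from the model config instead (see the transformers/modeling_xlm.py hunk at the bottom of this commit). How the example script actually sets those config fields is not visible in this hunk; the helper below is only a hypothetical illustration of that idea:

def configure_xlm_for_generation(model, tokenizer, language="en"):
    # lang_id and mask_token_id are attributes of XLMConfig (see
    # transformers/configuration_xlm.py below); setting them from the tokenizer
    # here is an assumption about how a caller might use the reworked API.
    model.config.lang_id = tokenizer.lang2id[language]
    model.config.mask_token_id = tokenizer.mask_token_id
    return model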
@@ -179,8 +180,8 @@ def main():
     try:
         args.model_type = args.model_type.lower()
         model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    except KeyError as ke:
-        raise ke(
+    except KeyError:
+        raise KeyError(
             "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
         )
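The change above fixes a real bug rather than just style: `ke` is bound to the caught exception instance, so `raise ke(...)` tries to call that instance and fails with a TypeError before the intended message is ever shown. Raising a fresh KeyError works. A minimal, self-contained sketch of the corrected pattern (the MODEL_CLASSES stand-in and the .format call are illustrative, not code from the commit):

MODEL_CLASSES = {"gpt2": None, "xlnet": None}  # illustrative stand-in for the real mapping

def lookup_model_classes(model_type):
    try:
        return MODEL_CLASSES[model_type.lower()]
    except KeyError:
        # A new exception is constructed here; the old `except KeyError as ke: raise ke(...)`
        # attempted to call the caught instance, which is not callable.
        raise KeyError("the model {} you specified is not supported. "
                       "You are welcome to add it and open a PR :)".format(model_type))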
@@ -197,10 +198,9 @@ def main():
     # Different models need different input formatting and/or extra arguments
     requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
-    model_kwargs = {}
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
-        prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
+        prompt_text = prepare_input(args, model, tokenizer, prompt_text)
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
 
     output_sequences = model.generate(
@@ -210,14 +210,11 @@ def main():
         top_k=args.k,
         top_p=args.p,
         repetition_penalty=args.repetition_penalty,
-        **model_kwargs,
     )
 
-    generated_sequence = output_sequences.tolist()[encoded_prompt.size(1):]  # adapted to case where num_samples > 1
-    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-    text = text[: text.find(args.stop_token) if args.stop_token else None]
+    generated_sequence = output_sequences.tolist()
+    text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence]
+    # text = text[: text.find(args.stop_token) if args.stop_token else None]
 
     print(text)
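The last hunk switches from decoding one flat token list to decoding each returned sequence separately, since generate can hand back several samples. A standalone sketch of that decoding path, with the stop-token trimming the old code did shown as an optional step (the function name and arguments are illustrative):

def decode_generated(tokenizer, output_sequences, stop_token=None):
    texts = []
    for sequence in output_sequences.tolist():       # one entry per returned sample
        text = tokenizer.decode(sequence, clean_up_tokenization_spaces=True)
        if stop_token:                                # optional trimming, as in the old code
            text = text[: text.find(stop_token)]
        texts.append(text)
    return texts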
transformers/configuration_xlm.py
@@ -113,8 +113,8 @@ class XLMConfig(PretrainedConfig):
                  summary_first_dropout=0.1,
                  start_n_top=5,
                  end_n_top=5,
-                 mask_token_id=0,
-                 lang_id=0,
+                 mask_token_id=0,
+                 lang_id=0,
                  **kwargs):
         """Constructs XLMConfig.
         """
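mask_token_id and lang_id are ordinary constructor arguments on XLMConfig, so the defaults shown above can be overridden when the config is built. A brief usage sketch (values are illustrative):

from transformers import XLMConfig

config = XLMConfig(mask_token_id=0, lang_id=0)  # defaults as in the signature above
print(config.mask_token_id, config.lang_id)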
transformers/modeling_utils.py
@@ -489,7 +489,7 @@ class PreTrainedModel(nn.Module):
     def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None,
                  top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None,
-                 length_penalty=None, num_return_sequences=None, **model_kwargs):
+                 length_penalty=None, num_return_sequences=None):
         """ Sequence generator for models with a LM head.
 
         The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
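Every argument of the trimmed generate signature above is optional; values left as None fall back to the matching attribute on model.config (as the context lines in the next hunk show for max_length and do_sample). A hedged call sketch — the parameter values are illustrative, not library defaults:

def sample_continuations(model, encoded_prompt):
    # encoded_prompt is assumed to be a torch.LongTensor of shape (batch_size, seq_len).
    return model.generate(
        input_ids=encoded_prompt,
        max_length=50,
        do_sample=True,
        temperature=0.9,
        top_k=40,
        top_p=0.95,
        repetition_penalty=1.0,
        num_return_sequences=1,
    )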
@@ -519,7 +519,8 @@ class PreTrainedModel(nn.Module):
         # We cannot generate if the model does not have a LM head
         if self.get_output_embeddings() is None:
-            raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
+            raise AttributeError("You tried to generate sequences with a model that does not have a LM Head."
+                                 "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")
 
         max_length = max_length if max_length is not None else self.config.max_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
@@ -544,7 +545,7 @@ class PreTrainedModel(nn.Module):
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
-        # assert temperature > 0, "`temperature` should be strictely positive."
+        # assert temperature >= 0, "`temperature` should be positive."
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
@@ -576,13 +577,11 @@ class PreTrainedModel(nn.Module):
             output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
                                                 repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size,
-                                                length_penalty, num_beams, vocab_size, **model_kwargs)
+                                                length_penalty, num_beams, vocab_size)
         else:
             output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
-                                                   repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size,
-                                                   **model_kwargs)
+                                                   repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size)
 
         if num_return_sequences != 1:
             output = output.view(batch_size, num_return_sequences, -1)
@@ -590,19 +589,18 @@ class PreTrainedModel(nn.Module):
     def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size, **model_kwargs):
+                                 pad_token_id, eos_token_ids, batch_size):
         """ Generate sequences for each example without beam search (num_beams == 1).
             All returned sequence are generated independantly.
         """
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
 
-        # cache compute states
+        # TODO: add cached compute states
         pasts = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
@@ -614,7 +612,7 @@ class PreTrainedModel(nn.Module):
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
-                if temperature != 1.0:
+                if temperature > 0 and temperature != 1.0:
                     next_token_logits = next_token_logits / temperature
                 # Top-p/top-k filtering
                 next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
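The added `temperature > 0` check matters because the logits are divided by the temperature on the next line; a temperature of 0 (commonly used to mean greedy decoding) would otherwise divide by zero and produce inf values. A standalone sketch of the guarded scaling (the helper name is illustrative):

import torch

def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale only for positive temperatures other than 1.0; a temperature of 0 is
    # left untouched so a greedy-decoding convention cannot divide by zero.
    if temperature > 0 and temperature != 1.0:
        return logits / temperature
    return logits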
@@ -644,8 +642,7 @@ class PreTrainedModel(nn.Module):
     def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
                               repetition_penalty, pad_token_id, eos_token_ids, batch_size,
-                              length_penalty, num_beams, vocab_size, **model_kwargs):
+                              length_penalty, num_beams, vocab_size):
         """ Generate sequences for each example with beam search.
         """
         # Expand input to num beams
@@ -667,7 +664,7 @@ class PreTrainedModel(nn.Module):
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
             scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
             scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
@@ -679,7 +676,7 @@ class PreTrainedModel(nn.Module):
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
-                if temperature != 1.0:
+                if temperature > 0 and temperature != 1.0:
                     scores = scores / temperature
                 # Top-p/top-k filtering
                 scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
transformers/modeling_xlm.py
@@ -639,9 +639,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj
 
-    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
-        mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
-        lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        mask_token_id = self.config.mask_token_id
+        lang_id = self.config.lang_id
 
         mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
         input_ids = torch.cat([input_ids, mask_token], dim=1)
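After this change prepare_inputs_for_generation reads mask_token_id and lang_id straight from the config instead of per-call kwargs. The surrounding context also shows XLM's MLM-style generation trick: a mask token is appended so the LM head predicts the next token at that position. A self-contained sketch of that append step (the helper name is illustrative):

import torch

def append_mask_token(input_ids: torch.LongTensor, mask_token_id: int) -> torch.LongTensor:
    # Append a single [MASK] position for the model to fill in at the next step.
    mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
    return torch.cat([input_ids, mask_token], dim=1)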