chenpangpang / transformers · Commits
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b73dd1a0e441e2fc215c013dc4ea2fed65db6c5f"
Commit 8e5587fb, authored Dec 18, 2019 by thomwolf
few fixes on sampling
Parent: 641a8dec
Showing 1 changed file with 42 additions and 53 deletions.
transformers/modeling_utils.py (+42, -53)
@@ -23,14 +23,12 @@ import json
 import logging
 import os
 from io import open
-import warnings
 import six
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
-from tqdm import trange
 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -82,7 +80,6 @@ class PreTrainedModel(nn.Module):
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
                 ))
         # Save config in model
         self.config = config
@@ -220,9 +217,6 @@ class PreTrainedModel(nn.Module):
         # Tie weights if needed
         self.tie_weights()
-        # Initialize decoding head if we have output embeddings

     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
@@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module):
...
@@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module):
cur_len
=
input_ids
.
shape
[
1
]
cur_len
=
input_ids
.
shape
[
1
]
vocab_size
=
self
.
config
.
vocab_size
vocab_size
=
self
.
config
.
vocab_size
if
num_return_sequences
!=
1
:
# Expand input to num return sequences
input_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
batch_size
,
num_return_sequences
,
cur_len
)
input_ids
=
input_ids
.
contiguous
().
view
(
batch_size
*
num_return_sequences
,
cur_len
)
# (batch_size * num_return_sequences, cur_len)
effective_batch_size
=
batch_size
*
num_return_sequences
else
:
effective_batch_size
=
batch_size
if
num_beams
>
1
:
if
num_beams
>
1
:
return
self
.
_generate_beam_search
(
input_ids
,
cur_len
,
max_length
,
do_sample
,
output
=
self
.
_generate_beam_search
(
input_ids
,
cur_len
,
max_length
,
do_sample
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
pad_token_id
,
eos_token_ids
,
batch_size
,
pad_token_id
,
eos_token_ids
,
effective_
batch_size
,
num_return_sequences
,
length_penalty
,
num_beams
,
vocab_size
)
length_penalty
,
num_beams
,
vocab_size
)
else
:
return
self
.
_generate_no_beam_search
(
input_ids
,
cur_len
,
max_length
,
do_sample
,
output
=
self
.
_generate_no_beam_search
(
input_ids
,
cur_len
,
max_length
,
do_sample
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
pad_token_id
,
eos_token_ids
,
batch_size
,
pad_token_id
,
eos_token_ids
,
effective_batch_size
)
num_return_sequences
)
if
num_return_sequences
!=
1
:
output
=
output
.
view
(
batch_size
,
num_return_sequences
,
-
1
)
return
output
def
_generate_no_beam_search
(
self
,
input_ids
,
cur_len
,
max_length
,
do_sample
,
def
_generate_no_beam_search
(
self
,
input_ids
,
cur_len
,
max_length
,
do_sample
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
temperature
,
top_k
,
top_p
,
repetition_penalty
,
pad_token_id
,
eos_token_ids
,
batch_size
,
pad_token_id
,
eos_token_ids
,
batch_size
):
num_return_sequences
):
""" Generate sequences for each example without beam search (num_beams == 1).
""" Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
All returned sequence are generated independantly.
All returned sequence are generated independantly.
"""
"""
# Expand input to num return sequences
input_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
batch_size
,
num_return_sequences
,
cur_len
)
input_ids
=
input_ids
.
contiguous
().
view
(
batch_size
*
num_return_sequences
,
cur_len
)
# (batch_size*num_return_sequences, cur_len)
# current position / max lengths / length of generated sentences / unfinished sentences
# current position / max lengths / length of generated sentences / unfinished sentences
unfinished_sents
=
input_ids
.
new
(
batch_size
*
num_return_sequences
).
fill_
(
1
)
unfinished_sents
=
input_ids
.
new
(
batch_size
).
fill_
(
1
)
# cache compute states
# cache compute states
pasts
=
None
pasts
=
None
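For readers following the refactor above, the expand / view round-trip that now lives in generate() can be checked on a toy tensor. This is an illustrative sketch with made-up shapes, not code from the commit:

import torch

batch_size, cur_len, num_return_sequences = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# Repeat each example num_return_sequences times, then flatten to a single batch dim,
# so the sampling loop can treat every copy as an independent sequence.
expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
flat = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)
assert flat.shape == (batch_size * num_return_sequences, cur_len)

# After generation, the flat batch is folded back so dim 0 indexes examples
# and dim 1 indexes the independent samples for each example.
output = torch.zeros(batch_size * num_return_sequences, 7, dtype=torch.long)
output = output.view(batch_size, num_return_sequences, -1)
assert output.shape == (batch_size, num_return_sequences, 7)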
@@ -604,7 +604,7 @@ class PreTrainedModel(nn.Module):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                for i in range(batch_size * num_return_sequences):
+                for i in range(batch_size):
                     for previous_tokens in set(input_ids[i].tolist()):
                         next_token_logits[i, previous_tokens] /= repetition_penalty
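As a side note, the CTRL-style repetition penalty applied above can be exercised on its own; the snippet below is an illustrative sketch with invented values, not part of the commit:

import torch

repetition_penalty = 1.3
next_token_logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])  # (batch_size=1, vocab_size=4)
input_ids = torch.tensor([[3, 3, 1]])                     # tokens generated so far

# Every token that already occurs in the sequence has its logit divided by the penalty,
# which lowers its probability of being sampled again.
for i in range(input_ids.shape[0]):
    for previous_token in set(input_ids[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty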
@@ -635,22 +635,14 @@ class PreTrainedModel(nn.Module):
         if cur_len == max_length:
             input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])

-        if num_return_sequences != 1:
-            input_ids = input_ids.view(batch_size, num_return_sequences, -1)
-
         return input_ids

     def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
                               temperature, top_k, top_p, repetition_penalty,
                               pad_token_id, eos_token_ids, batch_size,
-                              num_return_sequences,
                               length_penalty, num_beams, vocab_size):
-        """ Generate `num_return_sequences` sequences per batch example with beam search.
-            We return the top-`num_return_sequences` beams.
-            `num_return_sequences` should be bigger than `num_beams` (we default to the min of both)
+        """ Generate sequences for each example with beam search.
         """
-        num_return_sequences = min(num_return_sequences, num_beams)
-
         # Expand input to num beams
         input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
         input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
@@ -685,7 +677,7 @@ class PreTrainedModel(nn.Module):
                 if temperature != 1.0:
                     scores = scores / temperature
                 # Top-p/top-k filtering
-                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)  # (batch_size * num_beams, vocab_size)
+                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
                 next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
                 # Compute next scores
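The new min_tokens_to_keep=2 argument matters here because the line that follows draws two samples per beam; if filtering left only one admissible token, torch.multinomial without replacement could not return two distinct words. A minimal sketch of the corner case, with illustrative values only:

import torch
import torch.nn.functional as F

# After aggressive top-k/top-p filtering, a row may keep a single finite logit.
scores = torch.tensor([[5.0, float('-inf'), float('-inf'), float('-inf')]])
probs = F.softmax(scores, dim=-1)  # only one non-zero probability

# torch.multinomial(probs, num_samples=2) needs two non-zero categories when sampling
# without replacement; depending on the PyTorch version it errors or can pick a
# zero-probability token. Keeping at least two tokens per row avoids this.
# next_words = torch.multinomial(probs, num_samples=2)  # would hit the degenerate case above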
@@ -778,41 +770,35 @@ class PreTrainedModel(nn.Module):
             # print("")

         # select the best hypotheses
-        tgt_len = input_ids.new(batch_size, num_return_sequences)
-        bests = []
+        tgt_len = input_ids.new(batch_size)
+        best = []

         for i, hypotheses in enumerate(generated_hyps):
-            best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
-            for j, hyp in enumerate(best_hyps):
-                tgt_len[i, j] = len(hyp) + 1  # +1 for the <EOS> symbol
-            bests.append(best_hyps)
+            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+            best.append(best_hyp)

         # generate target batch
-        decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
-        for i, hyps in enumerate(bests):
-            for j, hypo in enumerate(hyps):
-                decoded[i, j, :tgt_len[i, j] - 1] = hypo
-                decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
-
-        if num_return_sequences == 1:
-            decoded = decoded.squeeze(1)
+        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+        for i, hypo in enumerate(best):
+            decoded[i, :tgt_len[i] - 1] = hypo
+            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
+
+        # # sanity check
+        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

         return decoded

-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (batch size x vocabulary size)
+            logits: logits distribution shape (batch size, vocabulary size)
             if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
-    top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value
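For the hypothesis selection earlier in this hunk: each entry of hypotheses.hyp is a (score, tokens) pair, and the new code keeps only the single highest-scoring beam per example instead of slicing the top num_return_sequences. A small illustration with made-up hypotheses, not code from the commit:

# Each hypothesis is a (score, tokens) pair; max over the score picks the best beam.
hyp = [(-1.2, [5, 8, 2]), (-0.4, [5, 9, 2]), (-2.0, [5, 7, 7, 2])]
best_hyp = max(hyp, key=lambda x: x[0])[1]
assert best_hyp == [5, 9, 2]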
@@ -821,8 +807,11 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

-        # Remove tokens with cumulative probability above the threshold
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
         sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
         # Shift the indices to the right to keep also the first token above the threshold
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
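To see the updated filter end to end, here is a usage sketch; it is illustrative rather than part of the commit. The import path matches where the function is defined in this file, and min_tokens_to_keep ensures the sampling call always has candidates left:

import torch
import torch.nn.functional as F
from transformers.modeling_utils import top_k_top_p_filtering

logits = torch.randn(2, 50)  # (batch size, vocabulary size)

# The function edits logits in place (filtered entries become -inf) and returns them,
# so pass a clone if the raw logits are still needed afterwards.
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9, min_tokens_to_keep=2)
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)  # (2, 1)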