chenpangpang / transformers · commit c0d9dd3b

refactored code a bit and made more generic

Authored Mar 05, 2020 by Patrick von Platen
Parent: d8e2b3c5

Showing 4 changed files with 28 additions and 16 deletions (+28, -16):

  src/transformers/configuration_utils.py   +1   -0
  src/transformers/modeling_utils.py        +25  -14
  tests/test_modeling_bart.py               +1   -1
  tests/test_modeling_common.py             +1   -1
src/transformers/configuration_utils.py
@@ -69,6 +69,7 @@ class PretrainedConfig(object):
         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
+        self.min_length = kwargs.pop("min_length", 0)
         self.do_sample = kwargs.pop("do_sample", False)
         self.early_stopping = kwargs.pop("early_stopping", False)
         self.num_beams = kwargs.pop("num_beams", 1)
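For context on the line added above: `PretrainedConfig` collects its generation defaults by popping known keys out of `**kwargs`, so an explicitly passed key overrides the default and anything not passed falls back. A minimal sketch of that pattern (not part of this commit; the class name and values are hypothetical):

class ToyGenerationConfig:
    # Hypothetical stand-in for PretrainedConfig's generation defaults.
    def __init__(self, **kwargs):
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.num_beams = kwargs.pop("num_beams", 1)


config = ToyGenerationConfig(min_length=5, num_beams=4)
# Explicit keys win; everything else keeps its default.
assert (config.min_length, config.max_length, config.num_beams) == (5, 20, 4)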
src/transformers/modeling_utils.py
@@ -609,6 +609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         self,
         input_ids=None,
         max_length=None,
+        min_length=None,
         do_sample=True,
         num_beams=None,
         temperature=None,
@@ -713,6 +714,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         )
         max_length = max_length if max_length is not None else self.config.max_length
+        min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         temperature = temperature if temperature is not None else self.config.temperature
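The hunk above applies the usual argument-resolution rule in `generate()`: a value passed by the caller wins, and only `None` falls back to the config default. A minimal sketch (not part of this commit; names are hypothetical) of why the check is `is not None` rather than plain truthiness:

def resolve(value, config_default):
    # Only None falls back; falsy-but-meaningful values (0, False) are kept.
    return value if value is not None else config_default


config_min_length = 0
assert resolve(None, config_min_length) == 0   # nothing passed -> config default
assert resolve(7, config_min_length) == 7      # caller override wins
assert resolve(0, 20) == 0                     # 0 is a valid override, not "missing"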
@@ -735,6 +737,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             eos_token_ids = [eos_token_ids]
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
+        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
         assert temperature > 0, "`temperature` should be strictly positive."
@@ -824,12 +827,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                eos_token_id,
+                # eos_token_id,
+                # Why eos_token_id here? bos_token_id makes more sense no?
+                bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
-            cur_len = 0
+            cur_len = 1
             self.model.decoder.generation_mode = True
         else:
             encoder_inputs = None
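In the encoder-decoder branch above, the original `input_ids` become the encoder inputs and the decoder is seeded with a single start token per beam, which is why `cur_len` now starts at 1. A minimal sketch of the shapes involved (not part of this commit; the sizes and token id are hypothetical):

import torch

effective_batch_size, num_beams, bos_token_id = 2, 4, 0

# One start token per (batch element x beam); generation continues from here.
decoder_input_ids = torch.full((effective_batch_size * num_beams, 1), bos_token_id, dtype=torch.long)

assert decoder_input_ids.shape == (8, 1)
cur_len = decoder_input_ids.shape[-1]  # 1: the seeded start token already counts toward the length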
@@ -840,6 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 input_ids,
                 cur_len,
                 max_length,
+                min_length,
                 do_sample,
                 temperature,
                 top_k,
@@ -859,6 +863,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 input_ids,
                 cur_len,
                 max_length,
+                min_length,
                 do_sample,
                 temperature,
                 top_k,
@@ -877,6 +882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         input_ids,
         cur_len,
         max_length,
+        min_length,
         do_sample,
         temperature,
         top_k,
@@ -911,6 +917,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if repetition_penalty != 1.0:
                 self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

+            if eos_token_ids is not None and cur_len < min_length:
+                for eos_token_id in eos_token_ids:
+                    next_token_logits[:, eos_token_id] = -10000.0  # set eos token prob to 0 as is done for attention masks
+
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
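The block added above is what enforces the new `min_length` during greedy/sampled decoding: as long as `cur_len < min_length`, the EOS logits are pushed to a large negative value, so softmax assigns them (near) zero probability and the sequence cannot terminate early. A minimal sketch of the effect (not part of this commit; sizes and ids are hypothetical):

import torch

vocab_size, eos_token_id, min_length, cur_len = 10, 2, 5, 3
next_token_logits = torch.zeros(1, vocab_size)

if cur_len < min_length:
    next_token_logits[:, eos_token_id] = -10000.0  # effectively bans EOS at this step

probs = torch.softmax(next_token_logits, dim=-1)
assert probs[0, eos_token_id].item() < 1e-6  # EOS is (numerically) impossible to sample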
@@ -965,6 +975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         input_ids,
         cur_len,
         max_length,
+        min_length,
         do_sample,
         temperature,
         top_k,
@@ -1022,6 +1033,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
                 )

+            if eos_token_ids is not None and cur_len < min_length:
+                for eos_token_id in eos_token_ids:
+                    next_token_logits[:, eos_token_id] = -10000.0  # set eos token prob to 0 as is done for attention masks
+
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
@@ -1056,18 +1071,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

             if is_encoder_decoder:
                 # TODO(PVP) to be refactored later
                 import math

-                # scores[scores != scores] = -math.inf # block nans
-                # scores[:, pad_token_id] = -math.inf
-                # TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
-                # if cur_len == 0: # Force BOS to be chosen
-                # scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here
-                # elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len
-                # scores[:, eos_token_ids[0]] = -math.inf
-                # elif cur_len == max_length: # FORCE EOS to be chosen
-                if cur_len == max_length:
-                    # FORCE EOS to be chosen
-                    scores[:, : eos_token_ids[0]] = -math.inf
-                    scores[:, eos_token_ids[0] + 1 :] = -math.inf
+                # scores[scores != scores] = -math.inf # block nans => seems very hacky here
+                # scores[:, pad_token_id] = -math.inf => seems very hacky here
+                # if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
+                # scores[:, self.config.bos_token_id + 1 :] = -math.inf
+                if cur_len == max_length - 1:
+                    # FORCE EOS to be chosen
+                    all_but_eos_mask = torch.tensor(
+                        [x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device
+                    )
+                    scores[:, all_but_eos_mask] = -10000.0

             assert scores.size() == (batch_size * num_beams, vocab_size)

             # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
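The refactored branch above forces the beams to emit EOS on the final step by masking every other vocabulary entry, now at `cur_len == max_length - 1` (the seeded start token already occupies one position) and with a finite `-10000.0` instead of `-math.inf`. A minimal sketch of the masking (not part of this commit; sizes and ids are hypothetical):

import torch

vocab_size, eos_token_ids = 10, [2]
scores = torch.randn(3, vocab_size)  # stands in for (batch_size * num_beams, vocab_size) log-probs

# Index every token that is not an EOS token and push its score far down.
all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long)
scores[:, all_but_eos_mask] = -10000.0

assert (scores.argmax(dim=-1) == eos_token_ids[0]).all()  # only EOS can be chosen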
@@ -1194,7 +1205,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # shorter batches are filled with pad_token
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined"
-            sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
+            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
             decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

         # fill with hypothesis and eos_token_id if necessary
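The one-line change above (dropping the `+ 1` on `max_length`) caps the padded output length at `max_length` when the finished hypotheses have different lengths; the output tensor is pre-filled with `pad_token_id` and each hypothesis is copied into its row. A minimal sketch of that padding step (not part of this commit; the token ids are hypothetical):

import torch

pad_token_id = 1
hypotheses = [torch.tensor([5, 6, 2]), torch.tensor([7, 2])]  # already end with EOS (= 2)

sent_max_len = max(h.numel() for h in hypotheses)
decoded = torch.full((len(hypotheses), sent_max_len), pad_token_id, dtype=torch.long)
for i, hypo in enumerate(hypotheses):
    decoded[i, : hypo.numel()] = hypo

# decoded == [[5, 6, 2], [7, 2, 1]] -> the shorter hypothesis is padded with pad_token_id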
tests/test_modeling_bart.py
@@ -442,7 +442,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
         gen_tokens_1 = hf.generate_1(tokens, num_beams=4, max_length=extra_len,)  # repetition_penalty=10.,
-        gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len, do_sample=False)  # repetition_penalty=10.,
+        gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len + 2, do_sample=False)  # repetition_penalty=10.,
         print("1: {}".format(gen_tokens_1))
         print("2: {}".format(gen_tokens))
         ipdb.set_trace()
tests/test_modeling_common.py
@@ -621,7 +621,7 @@ class ModelTesterMixin:
         with torch.no_grad():
             model(**inputs_dict)

-    def _A_test_lm_head_model_random_generate(self):
+    def test_lm_head_model_random_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict.get(