Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
333affcb
Commit
333affcb
authored
Mar 06, 2020
by
Patrick von Platen
Browse files
add current changes
parent
42121699
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
5 deletions
+9
-5
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+9
-5
No files found.
src/transformers/modeling_utils.py
View file @
333affcb
...
...
@@ -614,6 +614,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length
=
None
,
min_length
=
None
,
do_sample
=
True
,
early_stopping
=
False
,
num_beams
=
None
,
temperature
=
None
,
top_k
=
None
,
...
...
@@ -720,7 +721,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length
=
max_length
if
max_length
is
not
None
else
self
.
config
.
max_length
min_length
=
min_length
if
min_length
is
not
None
else
self
.
config
.
min_length
do_sample
=
do_sample
if
do_sample
is
not
None
else
self
.
config
.
do_sample
early_stopping
=
early_stopping
if
early_stopping
is
not
None
else
self
.
config
.
early_stopping
num_beams
=
num_beams
if
num_beams
is
not
None
else
self
.
config
.
num_beams
temperature
=
temperature
if
temperature
is
not
None
else
self
.
config
.
temperature
top_k
=
top_k
if
top_k
is
not
None
else
self
.
config
.
top_k
...
...
@@ -747,6 +748,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert
isinstance
(
max_length
,
int
)
and
max_length
>
0
,
"`max_length` should be a strictly positive integer."
assert
isinstance
(
min_length
,
int
)
and
min_length
>=
0
,
"`min_length` should be a positive integer."
assert
isinstance
(
do_sample
,
bool
),
"`do_sample` should be a boolean."
assert
isinstance
(
early_stopping
,
bool
),
"`early_stopping` should be a boolean."
assert
isinstance
(
num_beams
,
int
)
and
num_beams
>
0
,
"`num_beams` should be a strictly positive integer."
assert
temperature
>
0
,
"`temperature` should be strictly positive."
assert
isinstance
(
top_k
,
int
)
and
top_k
>=
0
,
"`top_k` should be a positive integer."
...
...
@@ -841,8 +843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs
=
input_ids
input_ids
=
torch
.
full
(
(
effective_batch_size
*
num_beams
,
1
),
#
eos_token_id,
bos_token_id
,
eos_token_id
,
#
bos_token_id,
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
dtype
=
torch
.
long
,
device
=
next
(
self
.
parameters
()).
device
,
...
...
@@ -860,6 +862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length
,
min_length
,
do_sample
,
early_stopping
,
temperature
,
top_k
,
top_p
,
...
...
@@ -1012,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length
,
min_length
,
do_sample
,
early_stopping
,
temperature
,
top_k
,
top_p
,
...
...
@@ -1033,7 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# generated hypotheses
generated_hyps
=
[
BeamHypotheses
(
num_beams
,
max_length
,
length_penalty
,
early_stopping
=
False
)
for
_
in
range
(
batch_size
)
BeamHypotheses
(
num_beams
,
max_length
-
1
,
length_penalty
,
early_stopping
=
early_stopping
)
for
_
in
range
(
batch_size
)
]
# scores for each sentence in the beam
...
...
@@ -1080,11 +1084,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# force eos to be chosen at end of generation for encoder-decoder models
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
if
self
.
config
.
is_encoder_decoder
:
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
if
cur_len
==
1
:
self
.
_force_token_ids_generation
(
next_token_logits
,
bos_token_id
)
if
cur_len
==
max_length
-
1
:
self
.
_force_token_ids_generation
(
next_token_logits
,
eos_token_ids
)
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
if
do_sample
:
# Temperature (higher temperature => more likely to sample low probability tokens)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment