Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1c377468
Commit
1c377468
authored
Dec 21, 2019
by
thomwolf
Browse files
fixing run_generation
parent
3d2096f5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
10 deletions
+9
-10
examples/run_generation.py
examples/run_generation.py
+4
-5
transformers/configuration_utils.py
transformers/configuration_utils.py
+0
-1
transformers/modeling_utils.py
transformers/modeling_utils.py
+5
-4
No files found.
examples/run_generation.py
View file @
1c377468
...
@@ -156,7 +156,7 @@ def main():
...
@@ -156,7 +156,7 @@ def main():
parser
.
add_argument
(
"--length"
,
type
=
int
,
default
=
20
)
parser
.
add_argument
(
"--length"
,
type
=
int
,
default
=
20
)
parser
.
add_argument
(
"--stop_token"
,
type
=
str
,
default
=
None
,
help
=
"Token at which text generation is stopped"
)
parser
.
add_argument
(
"--stop_token"
,
type
=
str
,
default
=
None
,
help
=
"Token at which text generation is stopped"
)
parser
.
add_argument
(
"--temperature"
,
type
=
float
,
default
=
1.0
,
help
=
"temperature of
0 implies
greedy sampling"
)
parser
.
add_argument
(
"--temperature"
,
type
=
float
,
default
=
1.0
,
help
=
"temperature of
1.0 has no effect, lower tend toward
greedy sampling"
)
parser
.
add_argument
(
"--repetition_penalty"
,
type
=
float
,
default
=
1.0
,
help
=
"primarily useful for CTRL model; in that case, use 1.2"
)
parser
.
add_argument
(
"--repetition_penalty"
,
type
=
float
,
default
=
1.0
,
help
=
"primarily useful for CTRL model; in that case, use 1.2"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--p"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--p"
,
type
=
float
,
default
=
0.9
)
...
@@ -187,7 +187,6 @@ def main():
...
@@ -187,7 +187,6 @@ def main():
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
model
.
eval
()
args
.
length
=
adjust_length_to_model
(
args
.
length
=
adjust_length_to_model
(
args
.
length
,
max_sequence_length
=
model
.
config
.
max_position_embeddings
args
.
length
,
max_sequence_length
=
model
.
config
.
max_position_embeddings
...
@@ -202,11 +201,11 @@ def main():
...
@@ -202,11 +201,11 @@ def main():
if
requires_preprocessing
:
if
requires_preprocessing
:
prepare_input
=
PREPROCESSING_FUNCTIONS
.
get
(
args
.
model_type
)
prepare_input
=
PREPROCESSING_FUNCTIONS
.
get
(
args
.
model_type
)
prompt_text
,
model_kwargs
=
prepare_input
(
args
,
model
,
tokenizer
,
prompt_text
)
prompt_text
,
model_kwargs
=
prepare_input
(
args
,
model
,
tokenizer
,
prompt_text
)
encoded_prompt
=
torch
.
tensor
(
tokenizer
.
encode
(
prompt_text
,
add_special_tokens
=
False
)).
unsqueeze
(
0
)
encoded_prompt
=
tokenizer
.
encode
(
prompt_text
,
add_special_tokens
=
False
,
return_tensors
=
'pt'
)
output_sequences
=
model
.
generate
(
output_sequences
=
model
.
generate
(
in
t
put_ids
=
encoded_prompt
,
input_ids
=
encoded_prompt
,
length
=
args
.
length
,
max_
length
=
args
.
length
,
temperature
=
args
.
temperature
,
temperature
=
args
.
temperature
,
top_k
=
args
.
k
,
top_k
=
args
.
k
,
top_p
=
args
.
p
,
top_p
=
args
.
p
,
...
...
transformers/configuration_utils.py
View file @
1c377468
...
@@ -72,7 +72,6 @@ class PretrainedConfig(object):
...
@@ -72,7 +72,6 @@ class PretrainedConfig(object):
self
.
bos_token_id
=
kwargs
.
pop
(
'bos_token_id'
,
0
)
self
.
bos_token_id
=
kwargs
.
pop
(
'bos_token_id'
,
0
)
self
.
pad_token_id
=
kwargs
.
pop
(
'pad_token_id'
,
0
)
self
.
pad_token_id
=
kwargs
.
pop
(
'pad_token_id'
,
0
)
self
.
eos_token_ids
=
kwargs
.
pop
(
'eos_token_ids'
,
0
)
self
.
eos_token_ids
=
kwargs
.
pop
(
'eos_token_ids'
,
0
)
self
.
batch_size
=
kwargs
.
pop
(
'batch_size'
,
1
)
self
.
length_penalty
=
kwargs
.
pop
(
'length_penalty'
,
1.
)
self
.
length_penalty
=
kwargs
.
pop
(
'length_penalty'
,
1.
)
self
.
num_return_sequences
=
kwargs
.
pop
(
'num_return_sequences'
,
1
)
self
.
num_return_sequences
=
kwargs
.
pop
(
'num_return_sequences'
,
1
)
...
...
transformers/modeling_utils.py
View file @
1c377468
...
@@ -485,9 +485,10 @@ class PreTrainedModel(nn.Module):
...
@@ -485,9 +485,10 @@ class PreTrainedModel(nn.Module):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
return
{
"input_ids"
:
input_ids
}
return
{
"input_ids"
:
input_ids
}
@
torch
.
no_grad
()
def
generate
(
self
,
input_ids
=
None
,
max_length
=
None
,
do_sample
=
None
,
num_beams
=
None
,
def
generate
(
self
,
input_ids
=
None
,
max_length
=
None
,
do_sample
=
None
,
num_beams
=
None
,
temperature
=
None
,
top_k
=
None
,
top_p
=
None
,
repetition_penalty
=
None
,
temperature
=
None
,
top_k
=
None
,
top_p
=
None
,
repetition_penalty
=
None
,
bos_token_id
=
None
,
pad_token_id
=
None
,
eos_token_ids
=
None
,
batch_size
=
None
,
bos_token_id
=
None
,
pad_token_id
=
None
,
eos_token_ids
=
None
,
length_penalty
=
None
,
num_return_sequences
=
None
,
**
model_kwargs
):
length_penalty
=
None
,
num_return_sequences
=
None
,
**
model_kwargs
):
""" Sequence generator for models with a LM head.
""" Sequence generator for models with a LM head.
...
@@ -530,19 +531,20 @@ class PreTrainedModel(nn.Module):
...
@@ -530,19 +531,20 @@ class PreTrainedModel(nn.Module):
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
config
.
bos_token_id
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
config
.
bos_token_id
pad_token_id
=
pad_token_id
if
pad_token_id
is
not
None
else
self
.
config
.
pad_token_id
pad_token_id
=
pad_token_id
if
pad_token_id
is
not
None
else
self
.
config
.
pad_token_id
eos_token_ids
=
eos_token_ids
if
eos_token_ids
is
not
None
else
self
.
config
.
eos_token_ids
eos_token_ids
=
eos_token_ids
if
eos_token_ids
is
not
None
else
self
.
config
.
eos_token_ids
batch_size
=
batch_size
if
batch_size
is
not
None
else
self
.
config
.
batch_size
length_penalty
=
length_penalty
if
length_penalty
is
not
None
else
self
.
config
.
length_penalty
length_penalty
=
length_penalty
if
length_penalty
is
not
None
else
self
.
config
.
length_penalty
num_return_sequences
=
num_return_sequences
if
num_return_sequences
is
not
None
else
self
.
config
.
num_return_sequences
num_return_sequences
=
num_return_sequences
if
num_return_sequences
is
not
None
else
self
.
config
.
num_return_sequences
if
input_ids
is
not
None
:
if
input_ids
is
not
None
:
batch_size
=
input_ids
.
shape
[
0
]
# overriden by the input batch_size
batch_size
=
input_ids
.
shape
[
0
]
# overriden by the input batch_size
else
:
batch_size
=
1
if
isinstance
(
eos_token_ids
,
int
):
if
isinstance
(
eos_token_ids
,
int
):
eos_token_ids
=
[
eos_token_ids
]
eos_token_ids
=
[
eos_token_ids
]
assert
isinstance
(
max_length
,
int
)
and
max_length
>
0
,
"`max_length` should be a strictely positive integer."
assert
isinstance
(
max_length
,
int
)
and
max_length
>
0
,
"`max_length` should be a strictely positive integer."
assert
isinstance
(
do_sample
,
bool
),
"`do_sample` should be a boolean."
assert
isinstance
(
do_sample
,
bool
),
"`do_sample` should be a boolean."
assert
isinstance
(
num_beams
,
int
)
and
num_beams
>
0
,
"`num_beams` should be a strictely positive integer."
assert
isinstance
(
num_beams
,
int
)
and
num_beams
>
0
,
"`num_beams` should be a strictely positive integer."
assert
temperature
>
0
,
"`temperature` should be strictely positive."
#
assert temperature > 0, "`temperature` should be strictely positive."
assert
isinstance
(
top_k
,
int
)
and
top_k
>=
0
,
"`top_k` should be a positive integer."
assert
isinstance
(
top_k
,
int
)
and
top_k
>=
0
,
"`top_k` should be a positive integer."
assert
0
<=
top_p
<=
1
,
"`top_p` should be between 0 and 1."
assert
0
<=
top_p
<=
1
,
"`top_p` should be between 0 and 1."
assert
repetition_penalty
>=
1.0
,
"`repetition_penalty` should be >= 1."
assert
repetition_penalty
>=
1.0
,
"`repetition_penalty` should be >= 1."
...
@@ -550,7 +552,6 @@ class PreTrainedModel(nn.Module):
...
@@ -550,7 +552,6 @@ class PreTrainedModel(nn.Module):
assert
isinstance
(
pad_token_id
,
int
)
and
pad_token_id
>=
0
,
"`pad_token_id` should be a positive integer."
assert
isinstance
(
pad_token_id
,
int
)
and
pad_token_id
>=
0
,
"`pad_token_id` should be a positive integer."
assert
isinstance
(
eos_token_ids
,
(
list
,
tuple
))
and
(
e
>=
0
for
e
in
eos_token_ids
),
\
assert
isinstance
(
eos_token_ids
,
(
list
,
tuple
))
and
(
e
>=
0
for
e
in
eos_token_ids
),
\
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
"`batch_size` should be a strictely positive integer."
assert
length_penalty
>
0
,
"`length_penalty` should be strictely positive."
assert
length_penalty
>
0
,
"`length_penalty` should be strictely positive."
assert
isinstance
(
num_return_sequences
,
int
)
and
num_return_sequences
>
0
,
"`num_return_sequences` should be a strictely positive integer."
assert
isinstance
(
num_return_sequences
,
int
)
and
num_return_sequences
>
0
,
"`num_return_sequences` should be a strictely positive integer."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment