Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a5997dd8
Commit
a5997dd8
authored
Oct 10, 2019
by
thomwolf
Browse files
better error messages
parent
43a237f1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
40 deletions
+57
-40
examples/run_generation.py
examples/run_generation.py
+23
-6
transformers/configuration_utils.py
transformers/configuration_utils.py
+9
-10
transformers/modeling_utils.py
transformers/modeling_utils.py
+10
-10
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+15
-14
No files found.
examples/run_generation.py
View file @
a5997dd8
...
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
...
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return
logits
return
logits
def
sample_sequence
(
model
,
length
,
context
,
num_samples
=
1
,
temperature
=
1
,
top_k
=
0
,
top_p
=
0.0
,
repetition_penalty
=
1.0
,
is_xlnet
=
False
,
xlm_lang
=
None
,
device
=
'cpu'
):
def
sample_sequence
(
model
,
length
,
context
,
num_samples
=
1
,
temperature
=
1
,
top_k
=
0
,
top_p
=
0.0
,
repetition_penalty
=
1.0
,
is_xlnet
=
False
,
is_xlm_mlm
=
False
,
xlm_mask_token
=
None
,
xlm_lang
=
None
,
device
=
'cpu'
):
context
=
torch
.
tensor
(
context
,
dtype
=
torch
.
long
,
device
=
device
)
context
=
torch
.
tensor
(
context
,
dtype
=
torch
.
long
,
device
=
device
)
context
=
context
.
unsqueeze
(
0
).
repeat
(
num_samples
,
1
)
context
=
context
.
unsqueeze
(
0
).
repeat
(
num_samples
,
1
)
generated
=
context
generated
=
context
...
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
...
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
target_mapping
[
0
,
0
,
-
1
]
=
1.0
# predict last token
target_mapping
[
0
,
0
,
-
1
]
=
1.0
# predict last token
inputs
=
{
'input_ids'
:
input_ids
,
'perm_mask'
:
perm_mask
,
'target_mapping'
:
target_mapping
}
inputs
=
{
'input_ids'
:
input_ids
,
'perm_mask'
:
perm_mask
,
'target_mapping'
:
target_mapping
}
if
is_xlm_mlm
and
xlm_mask_token
:
# XLM MLM models are direct models (predict same token, not next token)
# => need one additional dummy token in the input (will be masked and guessed)
input_ids
=
torch
.
cat
((
generated
,
torch
.
full
((
1
,
1
),
xlm_mask_token
,
dtype
=
torch
.
long
,
device
=
device
)),
dim
=
1
)
inputs
=
{
'input_ids'
:
input_ids
}
if
xlm_lang
is
not
None
:
if
xlm_lang
is
not
None
:
inputs
[
"langs"
]
=
torch
.
tensor
([
xlm_lang
]
*
inputs
[
"input_ids"
].
shape
[
1
],
device
=
device
).
view
(
1
,
-
1
)
inputs
[
"langs"
]
=
torch
.
tensor
([
xlm_lang
]
*
inputs
[
"input_ids"
].
shape
[
1
],
device
=
device
).
view
(
1
,
-
1
)
outputs
=
model
(
**
inputs
)
# Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
outputs
=
model
(
**
inputs
)
# Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet
/CTRL
(cached hidden-states)
next_token_logits
=
outputs
[
0
][
0
,
-
1
,
:]
/
(
temperature
if
temperature
>
0
else
1.
)
next_token_logits
=
outputs
[
0
][
0
,
-
1
,
:]
/
(
temperature
if
temperature
>
0
else
1.
)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
...
@@ -167,10 +174,7 @@ def main():
...
@@ -167,10 +174,7 @@ def main():
parser
.
add_argument
(
'--stop_token'
,
type
=
str
,
default
=
None
,
parser
.
add_argument
(
'--stop_token'
,
type
=
str
,
default
=
None
,
help
=
"Token at which text generation is stopped"
)
help
=
"Token at which text generation is stopped"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
model_type
in
[
"ctrl"
]:
if
args
.
temperature
>
0.7
:
print
(
'CTRL typically works better with lower temperatures (and lower top_k).'
)
args
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
args
.
no_cuda
else
"cpu"
)
args
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
args
.
no_cuda
else
"cpu"
)
args
.
n_gpu
=
torch
.
cuda
.
device_count
()
args
.
n_gpu
=
torch
.
cuda
.
device_count
()
...
@@ -191,6 +195,10 @@ def main():
...
@@ -191,6 +195,10 @@ def main():
args
.
length
=
MAX_LENGTH
# avoid infinite loop
args
.
length
=
MAX_LENGTH
# avoid infinite loop
print
(
args
)
print
(
args
)
if
args
.
model_type
in
[
"ctrl"
]:
if
args
.
temperature
>
0.7
:
logger
.
info
(
'CTRL typically works better with lower temperatures (and lower top_k).'
)
while
True
:
while
True
:
xlm_lang
=
None
xlm_lang
=
None
# XLM language usage is detailed in issue #1414
# XLM language usage is detailed in issue #1414
...
@@ -204,6 +212,13 @@ def main():
...
@@ -204,6 +212,13 @@ def main():
language
=
input
(
"Using XLM. Select language in "
+
str
(
list
(
tokenizer
.
lang2id
.
keys
()))
+
" >>> "
)
language
=
input
(
"Using XLM. Select language in "
+
str
(
list
(
tokenizer
.
lang2id
.
keys
()))
+
" >>> "
)
xlm_lang
=
tokenizer
.
lang2id
[
language
]
xlm_lang
=
tokenizer
.
lang2id
[
language
]
# XLM masked-language modeling (MLM) models need a mask token (see details in sample_sequence)
is_xlm_mlm
=
args
.
model_type
in
[
"xlm"
]
and
'mlm'
in
args
.
model_name_or_path
if
is_xlm_mlm
:
xlm_mask_token
=
tokenizer
.
mask_token_id
else
:
xlm_mask_token
=
None
raw_text
=
args
.
prompt
if
args
.
prompt
else
input
(
"Model prompt >>> "
)
raw_text
=
args
.
prompt
if
args
.
prompt
else
input
(
"Model prompt >>> "
)
if
args
.
model_type
in
[
"transfo-xl"
,
"xlnet"
]:
if
args
.
model_type
in
[
"transfo-xl"
,
"xlnet"
]:
# Models with memory like to have a long prompt for short inputs.
# Models with memory like to have a long prompt for short inputs.
...
@@ -218,6 +233,8 @@ def main():
...
@@ -218,6 +233,8 @@ def main():
top_p
=
args
.
top_p
,
top_p
=
args
.
top_p
,
repetition_penalty
=
args
.
repetition_penalty
,
repetition_penalty
=
args
.
repetition_penalty
,
is_xlnet
=
bool
(
args
.
model_type
==
"xlnet"
),
is_xlnet
=
bool
(
args
.
model_type
==
"xlnet"
),
is_xlm_mlm
=
is_xlm_mlm
,
xlm_mask_token
=
xlm_mask_token
,
xlm_lang
=
xlm_lang
,
xlm_lang
=
xlm_lang
,
device
=
args
.
device
,
device
=
args
.
device
,
)
)
...
...
transformers/configuration_utils.py
View file @
a5997dd8
...
@@ -130,20 +130,19 @@ class PretrainedConfig(object):
...
@@ -130,20 +130,19 @@ class PretrainedConfig(object):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
logger
.
error
(
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
)
config_file
))
else
:
else
:
logger
.
error
(
msg
=
"Model name '{}' was not found in model name list ({}). "
\
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to a configuration file named {} or "
\
"We assumed '{}' was a path or url but couldn't find any file "
"a directory containing such a file but couldn't find any such file at this path or url."
.
format
(
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_config_archive_map
.
keys
()),
', '
.
join
(
cls
.
pretrained_config_archive_map
.
keys
()),
config_file
))
config_file
,
CONFIG_NAME
)
raise
e
raise
EnvironmentError
(
msg
)
if
resolved_config_file
==
config_file
:
if
resolved_config_file
==
config_file
:
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
else
:
...
...
transformers/modeling_utils.py
View file @
a5997dd8
...
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
...
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
)
archive_file
))
else
:
else
:
logger
.
error
(
msg
=
"Model name '{}' was not found in model name list ({}). "
\
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to model weight files named one of {} but "
\
"We assumed '{}' was a path or url but couldn't find any file "
"couldn't find any such file at this path or url."
.
format
(
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
archive_file
,
raise
e
[
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
])
raise
EnvironmentError
(
msg
)
if
resolved_archive_file
==
archive_file
:
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
else
:
...
...
transformers/tokenization_utils.py
View file @
a5997dd8
...
@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
...
@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
vocab_files
[
file_id
]
=
full_file_name
vocab_files
[
file_id
]
=
full_file_name
if
all
(
full_file_name
is
None
for
full_file_name
in
vocab_files
.
values
()):
if
all
(
full_file_name
is
None
for
full_file_name
in
vocab_files
.
values
()):
logger
.
e
rror
(
raise
EnvironmentE
rror
(
"Model name '{}' was not found in model name list ({}). "
"Model name '{}' was not found in
tokenizers
model name list ({}). "
"We assumed '{}' was a path or url
but couldn't find tokenizer
files"
"We assumed '{}' was a path or url
to a directory containing vocabulary
files
"
"at this path or url."
.
format
(
"
named {} but couldn't find such vocabulary files
at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
))
pretrained_model_name_or_path
,
return
None
list
(
cls
.
vocab_files_names
.
values
())))
# Get files from url, cache, or disk depending on the case
# Get files from url, cache, or disk depending on the case
try
:
try
:
...
@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
...
@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
resolved_vocab_files
[
file_id
]
=
None
resolved_vocab_files
[
file_id
]
=
None
else
:
else
:
resolved_vocab_files
[
file_id
]
=
cached_path
(
file_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_vocab_files
[
file_id
]
=
cached_path
(
file_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
s3_models
:
if
pretrained_model_name_or_path
in
s3_models
:
logger
.
error
(
"Couldn't reach server to download vocabulary."
)
msg
=
"Couldn't reach server
at '{}'
to download vocabulary
files
."
else
:
else
:
logger
.
error
(
msg
=
"Model name '{}' was not found in tokenizers model name list ({}). "
\
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
\
"We assumed '{}' was a path or url but couldn't find files {} "
"named {}, but couldn't find such vocabulary files at this path or url."
.
format
(
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
str
(
vocab_files
.
keys
())))
pretrained_model_name_or_path
,
raise
e
list
(
cls
.
vocab_files_names
.
values
()))
raise
EnvironmentError
(
msg
)
for
file_id
,
file_path
in
vocab_files
.
items
():
for
file_id
,
file_path
in
vocab_files
.
items
():
if
file_path
==
resolved_vocab_files
[
file_id
]:
if
file_path
==
resolved_vocab_files
[
file_id
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment