Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a5997dd8
Commit
a5997dd8
authored
Oct 10, 2019
by
thomwolf
Browse files
better error messages
parent
43a237f1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
40 deletions
+57
-40
examples/run_generation.py
examples/run_generation.py
+23
-6
transformers/configuration_utils.py
transformers/configuration_utils.py
+9
-10
transformers/modeling_utils.py
transformers/modeling_utils.py
+10
-10
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+15
-14
No files found.
examples/run_generation.py
View file @
a5997dd8
...
...
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return
logits
def
sample_sequence
(
model
,
length
,
context
,
num_samples
=
1
,
temperature
=
1
,
top_k
=
0
,
top_p
=
0.0
,
repetition_penalty
=
1.0
,
is_xlnet
=
False
,
xlm_lang
=
None
,
device
=
'cpu'
):
def
sample_sequence
(
model
,
length
,
context
,
num_samples
=
1
,
temperature
=
1
,
top_k
=
0
,
top_p
=
0.0
,
repetition_penalty
=
1.0
,
is_xlnet
=
False
,
is_xlm_mlm
=
False
,
xlm_mask_token
=
None
,
xlm_lang
=
None
,
device
=
'cpu'
):
context
=
torch
.
tensor
(
context
,
dtype
=
torch
.
long
,
device
=
device
)
context
=
context
.
unsqueeze
(
0
).
repeat
(
num_samples
,
1
)
generated
=
context
...
...
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
target_mapping
[
0
,
0
,
-
1
]
=
1.0
# predict last token
inputs
=
{
'input_ids'
:
input_ids
,
'perm_mask'
:
perm_mask
,
'target_mapping'
:
target_mapping
}
if
is_xlm_mlm
and
xlm_mask_token
:
# XLM MLM models are direct models (predict same token, not next token)
# => need one additional dummy token in the input (will be masked and guessed)
input_ids
=
torch
.
cat
((
generated
,
torch
.
full
((
1
,
1
),
xlm_mask_token
,
dtype
=
torch
.
long
,
device
=
device
)),
dim
=
1
)
inputs
=
{
'input_ids'
:
input_ids
}
if
xlm_lang
is
not
None
:
inputs
[
"langs"
]
=
torch
.
tensor
([
xlm_lang
]
*
inputs
[
"input_ids"
].
shape
[
1
],
device
=
device
).
view
(
1
,
-
1
)
outputs
=
model
(
**
inputs
)
# Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
outputs
=
model
(
**
inputs
)
# Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet
/CTRL
(cached hidden-states)
next_token_logits
=
outputs
[
0
][
0
,
-
1
,
:]
/
(
temperature
if
temperature
>
0
else
1.
)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
...
...
@@ -167,10 +174,7 @@ def main():
parser
.
add_argument
(
'--stop_token'
,
type
=
str
,
default
=
None
,
help
=
"Token at which text generation is stopped"
)
args
=
parser
.
parse_args
()
if
args
.
model_type
in
[
"ctrl"
]:
if
args
.
temperature
>
0.7
:
print
(
'CTRL typically works better with lower temperatures (and lower top_k).'
)
args
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
args
.
no_cuda
else
"cpu"
)
args
.
n_gpu
=
torch
.
cuda
.
device_count
()
...
...
@@ -191,6 +195,10 @@ def main():
args
.
length
=
MAX_LENGTH
# avoid infinite loop
print
(
args
)
if
args
.
model_type
in
[
"ctrl"
]:
if
args
.
temperature
>
0.7
:
logger
.
info
(
'CTRL typically works better with lower temperatures (and lower top_k).'
)
while
True
:
xlm_lang
=
None
# XLM language usage is detailed in issue #1414
...
...
@@ -204,6 +212,13 @@ def main():
language
=
input
(
"Using XLM. Select language in "
+
str
(
list
(
tokenizer
.
lang2id
.
keys
()))
+
" >>> "
)
xlm_lang
=
tokenizer
.
lang2id
[
language
]
# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
is_xlm_mlm
=
args
.
model_type
in
[
"xlm"
]
and
'mlm'
in
args
.
model_name_or_path
if
is_xlm_mlm
:
xlm_mask_token
=
tokenizer
.
mask_token_id
else
:
xlm_mask_token
=
None
raw_text
=
args
.
prompt
if
args
.
prompt
else
input
(
"Model prompt >>> "
)
if
args
.
model_type
in
[
"transfo-xl"
,
"xlnet"
]:
# Models with memory like to have a long prompt for short inputs.
...
...
@@ -218,6 +233,8 @@ def main():
top_p
=
args
.
top_p
,
repetition_penalty
=
args
.
repetition_penalty
,
is_xlnet
=
bool
(
args
.
model_type
==
"xlnet"
),
is_xlm_mlm
=
is_xlm_mlm
,
xlm_mask_token
=
xlm_mask_token
,
xlm_lang
=
xlm_lang
,
device
=
args
.
device
,
)
...
...
transformers/configuration_utils.py
View file @
a5997dd8
...
...
@@ -130,20 +130,19 @@ class PretrainedConfig(object):
# redirect to the cache, if necessary
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
)
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
msg
=
"Model name '{}' was not found in model name list ({}). "
\
"We assumed '{}' was a path or url to a configuration file named {} or "
\
"a directory containing such a file but couldn't find any such file at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_config_archive_map
.
keys
()),
config_file
))
raise
e
config_file
,
CONFIG_NAME
)
raise
EnvironmentError
(
msg
)
if
resolved_config_file
==
config_file
:
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
...
...
transformers/modeling_utils.py
View file @
a5997dd8
...
...
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
)
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
msg
=
"Model name '{}' was not found in model name list ({}). "
\
"We assumed '{}' was a path or url to model weight files named one of {} but "
\
"couldn't find any such file at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
raise
e
archive_file
,
[
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
])
raise
EnvironmentError
(
msg
)
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
...
...
transformers/tokenization_utils.py
View file @
a5997dd8
...
...
@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
vocab_files
[
file_id
]
=
full_file_name
if
all
(
full_file_name
is
None
for
full_file_name
in
vocab_files
.
values
()):
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url
but couldn't find tokenizer
files"
"at this path or url."
.
format
(
raise
EnvironmentError
(
"Model name '{}' was not found in
tokenizers
model name list ({}). "
"We assumed '{}' was a path or url
to a directory containing vocabulary
files
"
"
named {} but couldn't find such vocabulary files
at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
))
return
None
pretrained_model_name_or_path
,
list
(
cls
.
vocab_files_names
.
values
())))
# Get files from url, cache, or disk depending on the case
try
:
...
...
@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
resolved_vocab_files
[
file_id
]
=
None
else
:
resolved_vocab_files
[
file_id
]
=
cached_path
(
file_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
s3_models
:
logger
.
error
(
"Couldn't reach server to download vocabulary."
)
msg
=
"Couldn't reach server
at '{}'
to download vocabulary
files
."
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url."
.
format
(
msg
=
"Model name '{}' was not found in tokenizers model name list ({}). "
\
"We assumed '{}' was a path or url to a directory containing vocabulary files "
\
"named {}, but couldn't find such vocabulary files at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
s3_models
),
pretrained_model_name_or_path
,
str
(
vocab_files
.
keys
())))
raise
e
pretrained_model_name_or_path
,
list
(
cls
.
vocab_files_names
.
values
()))
raise
EnvironmentError
(
msg
)
for
file_id
,
file_path
in
vocab_files
.
items
():
if
file_path
==
resolved_vocab_files
[
file_id
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment