chenpangpang / transformers · Commits

Commit a5997dd8
authored Oct 10, 2019 by thomwolf
parent 43a237f1

better error messages
Showing 4 changed files with 57 additions and 40 deletions (+57 -40):

    examples/run_generation.py            +23  -6
    transformers/configuration_utils.py    +9  -10
    transformers/modeling_utils.py        +10  -10
    transformers/tokenization_utils.py    +15  -14
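Note: the three library files all make the same change to their download error paths. Instead of logging an explanation with logger.error() and then re-raising the bare exception (or, in the tokenizer's case, returning None), the explanation is now attached to a freshly raised EnvironmentError. A minimal, self-contained sketch of the before/after pattern follows; cached_path here is a stub standing in for the real download-and-cache helper in transformers.file_utils.

import logging

logger = logging.getLogger(__name__)

def cached_path(url):
    # Stand-in for transformers' download-and-cache helper; it always fails
    # here so that the two error paths below can be exercised.
    raise EnvironmentError()

def fetch_before(config_file):
    # Old pattern: the explanation only reaches the log; callers catching
    # the exception see a bare EnvironmentError with no message.
    try:
        return cached_path(config_file)
    except EnvironmentError as e:
        logger.error("Couldn't reach server at '{}'.".format(config_file))
        raise e

def fetch_after(config_file):
    # New pattern: the explanation travels with the exception itself and
    # shows up in the traceback even when logging isn't configured.
    try:
        return cached_path(config_file)
    except EnvironmentError:
        msg = "Couldn't reach server at '{}' to download the file.".format(config_file)
        raise EnvironmentError(msg)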
examples/run_generation.py
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
 
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
-                    is_xlnet=False, xlm_lang=None, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
+                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None,
+                    xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context

@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
 
+            if is_xlm_mlm and xlm_mask_token:
+                # XLM MLM models are direct models (predict same token, not next token)
+                # => need one additional dummy token in the input (will be masked and guessed)
+                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
+                inputs = {'input_ids': input_ids}
+
             if xlm_lang is not None:
                 inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
 
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
 
             # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)

@@ -167,9 +174,6 @@ def main():
     parser.add_argument('--stop_token', type=str, default=None,
                         help="Token at which text generation is stopped")
     args = parser.parse_args()
 
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            print('CTRL typically works better with lower temperatures (and lower top_k).')
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
     args.n_gpu = torch.cuda.device_count()

@@ -191,6 +195,10 @@ def main():
         args.length = MAX_LENGTH  # avoid infinite loop
     print(args)
 
+    if args.model_type in ["ctrl"]:
+        if args.temperature > 0.7:
+            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
+
     while True:
         xlm_lang = None
         # XLM Language usage detailed in the issues #1414

@@ -204,6 +212,13 @@ def main():
             language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
             xlm_lang = tokenizer.lang2id[language]
 
+        # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
+        is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
+        if is_xlm_mlm:
+            xlm_mask_token = tokenizer.mask_token_id
+        else:
+            xlm_mask_token = None
+
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory likes to have a long prompt for short inputs.

@@ -218,6 +233,8 @@ def main():
             top_p=args.top_p,
             repetition_penalty=args.repetition_penalty,
             is_xlnet=bool(args.model_type == "xlnet"),
+            is_xlm_mlm=is_xlm_mlm,
+            xlm_mask_token=xlm_mask_token,
             xlm_lang=xlm_lang,
             device=args.device,
         )
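Note: the new is_xlm_mlm branch is the interesting part of this file. A masked-language model predicts the token at a masked position rather than the next token, so each sampling step appends one dummy mask token, lets the model fill it in, and keeps the prediction as the next token. A toy, self-contained sketch of that loop follows; ToyMLM and the mask id of 5 are stand-ins, not the script's real model or tokenizer.

import torch

MASK_TOKEN_ID = 5  # stand-in for tokenizer.mask_token_id

class ToyMLM(torch.nn.Module):
    # Minimal stand-in for an XLM model with an MLM head:
    # embed tokens, project back to vocabulary logits.
    def __init__(self, vocab_size=10, hidden=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq, vocab)

def sample_step_mlm(model, generated, mask_token_id, device='cpu'):
    # Append one dummy mask token and ask the model to guess *that* position;
    # the guess stands in for "the next token" of the running sequence.
    input_ids = torch.cat(
        (generated, torch.full((1, 1), mask_token_id, dtype=torch.long, device=device)),
        dim=1)
    logits = model(input_ids)
    next_token = logits[0, -1, :].argmax().view(1, 1)  # greedy, for brevity
    return torch.cat((generated, next_token), dim=1)

model = ToyMLM()
generated = torch.tensor([[1, 2, 3]], dtype=torch.long)
for _ in range(4):
    generated = sample_step_mlm(model, generated, MASK_TOKEN_ID)
print(generated)  # the context plus four sampled tokens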
transformers/configuration_utils.py
@@ -130,20 +130,19 @@ class PretrainedConfig(object):
         # redirect to the cache, if necessary
         try:
             resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
+                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                    config_file)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(cls.pretrained_config_archive_map.keys()),
-                        config_file))
-            raise e
+                msg = "Model name '{}' was not found in model name list ({}). " \
+                      "We assumed '{}' was a path or url to a configuration file named {} or " \
+                      "a directory containing such a file but couldn't find any such file at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(cls.pretrained_config_archive_map.keys()),
+                        config_file, CONFIG_NAME)
+            raise EnvironmentError(msg)
 
         if resolved_config_file == config_file:
             logger.info("loading configuration file {}".format(config_file))
         else:
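Note: to see what the reworded message buys, it can be rendered from the format string above. The model names below are placeholders (the real list comes from cls.pretrained_config_archive_map); CONFIG_NAME is the library's constant for the configuration file name.

CONFIG_NAME = 'config.json'  # matches transformers' constant of this era

msg = "Model name '{}' was not found in model name list ({}). " \
      "We assumed '{}' was a path or url to a configuration file named {} or " \
      "a directory containing such a file but couldn't find any such file at this path or url.".format(
          './my-typo-model',            # placeholder model name/path
          'bert-base-uncased, gpt2',    # placeholder archive-map keys
          './my-typo-model', CONFIG_NAME)
print(msg)  # the text a user now sees on the raised EnvironmentError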
transformers/modeling_utils.py
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
+                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
+                    archive_file)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(cls.pretrained_model_archive_map.keys()),
-                        archive_file))
-            raise e
+                msg = "Model name '{}' was not found in model name list ({}). " \
+                      "We assumed '{}' was a path or url to model weight files named one of {} but " \
+                      "couldn't find any such file at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(cls.pretrained_model_archive_map.keys()),
+                        archive_file,
+                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
+            raise EnvironmentError(msg)
 
         if resolved_archive_file == archive_file:
             logger.info("loading weights file {}".format(archive_file))
         else:
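Note: the weights variant goes one step further and lists the candidate file names it looked for. The three constants interpolated into the message are, in the transformers of this era (values worth double-checking against the installed version; treat them as illustrative):

# Values as defined in transformers around v2.1.
WEIGHTS_NAME = 'pytorch_model.bin'   # PyTorch state dict
TF2_WEIGHTS_NAME = 'tf_model.h5'     # TF 2.0 Keras weights
TF_WEIGHTS_NAME = 'model.ckpt'       # original TF 1.x checkpoint prefix

print("We assumed '{}' was a path or url to model weight files named one of {} but "
      "couldn't find any such file at this path or url.".format(
          './my-model', [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME]))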
transformers/tokenization_utils.py
@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
                     vocab_files[file_id] = full_file_name
 
         if all(full_file_name is None for full_file_name in vocab_files.values()):
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find tokenizer files"
-                "at this path or url.".format(
-                    pretrained_model_name_or_path, ', '.join(s3_models),
-                    pretrained_model_name_or_path,))
-            return None
+            raise EnvironmentError(
+                "Model name '{}' was not found in tokenizers model name list ({}). "
+                "We assumed '{}' was a path or url to a directory containing vocabulary files "
+                "named {} but couldn't find such vocabulary files at this path or url.".format(
+                    pretrained_model_name_or_path, ', '.join(s3_models),
+                    pretrained_model_name_or_path,
+                    list(cls.vocab_files_names.values())))
 
         # Get files from url, cache, or disk depending on the case
         try:

@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
                     resolved_vocab_files[file_id] = None
                 else:
                     resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in s3_models:
-                logger.error("Couldn't reach server to download vocabulary.")
+                msg = "Couldn't reach server at '{}' to download vocabulary files."
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ', '.join(s3_models),
-                        pretrained_model_name_or_path,
-                        str(vocab_files.keys())))
-            raise e
+                msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
+                      "We assumed '{}' was a path or url to a directory containing vocabulary files " \
+                      "named {}, but couldn't find such vocabulary files at this path or url.".format(
+                        pretrained_model_name_or_path, ', '.join(s3_models),
+                        pretrained_model_name_or_path,
+                        list(cls.vocab_files_names.values()))
+            raise EnvironmentError(msg)
 
         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]:
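Note: the first tokenization hunk changes behavior, not just wording. from_pretrained used to log the problem and return None, so callers crashed later with an opaque AttributeError; now the failure is raised at the call site with the explanation attached. A toy illustration of the failure mode being removed (hypothetical class, not the real tokenizer):

class ToyTokenizer:
    @classmethod
    def from_pretrained_old(cls, name):
        # Old behavior: report the lookup failure, then hand back None.
        print("Model name '{}' was not found...".format(name))  # logger.error stand-in
        return None

    @classmethod
    def from_pretrained_new(cls, name):
        # New behavior: fail loudly, with the explanation on the exception.
        raise EnvironmentError(
            "Model name '{}' was not found in tokenizers model name list.".format(name))

tok = ToyTokenizer.from_pretrained_old('bert-base-uncassed')  # note the typo
# tok.tokenize("hello")   -> AttributeError: 'NoneType' object has no attribute 'tokenize'
# ToyTokenizer.from_pretrained_new('bert-base-uncassed')  -> EnvironmentError, immediately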