chenpangpang / transformers

Commit c2c2ca0f, authored Oct 03, 2019 by LysandreJik

Added XLM to run_generation, with prompt language selection.

Parent: 1569610f
Showing 1 changed file (examples/run_generation.py) with 24 additions and 5 deletions (+24 / -5)
examples/run_generation.py

@@ -26,12 +26,13 @@ import torch
 import torch.nn.functional as F
 import numpy as np
 
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
+from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig
 
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
 from transformers import XLNetLMHeadModel, XLNetTokenizer
 from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
+from transformers import XLMWithLMHeadModel, XLMTokenizer
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -41,13 +42,14 @@ logger = logging.getLogger(__name__)
 
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
 
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ())
 
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
     'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
     'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
+    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
 }
 
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
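
Side note on the registry above: MODEL_CLASSES is what resolves the script's --model_type string into concrete model and tokenizer classes. A minimal sketch of that lookup, assuming the XLM checkpoint shortcut 'xlm-mlm-enfr-1024' (any name listed in ALL_MODELS would do):

# Sketch: resolve --model_type through MODEL_CLASSES, as the script does.
from transformers import XLMWithLMHeadModel, XLMTokenizer

MODEL_CLASSES = {'xlm': (XLMWithLMHeadModel, XLMTokenizer)}

model_class, tokenizer_class = MODEL_CLASSES['xlm']
tokenizer = tokenizer_class.from_pretrained('xlm-mlm-enfr-1024')  # assumed shortcut name
model = model_class.from_pretrained('xlm-mlm-enfr-1024')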
@@ -103,7 +105,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
 
 
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
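
For readers tracking tensor shapes: the two context lines above turn a flat list of prompt token ids into a (num_samples, seq_len) batch. A standalone sketch with made-up token ids:

import torch

# Sketch of the context preparation in sample_sequence; ids are assumptions.
context = [40, 1842, 11059]  # e.g. an encoded 3-token prompt
num_samples = 3

context = torch.tensor(context, dtype=torch.long)
context = context.unsqueeze(0).repeat(num_samples, 1)
print(context.shape)  # torch.Size([3, 3]) -- one row per sample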
@@ -121,6 +124,9 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
 
+            if xlm_lang is not None:
+                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1)
+
             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / temperature
             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
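
The "langs" entry added here is how XLM selects its language embeddings: one language id per token position, in a tensor shaped like input_ids. A quick standalone sketch, assuming a language id of 0 (e.g. what tokenizer.lang2id might return for 'en'):

import torch

# Sketch: build the 'langs' input the same way the diff does.
xlm_lang = 0                                     # assumed language id
input_ids = torch.tensor([[11, 27, 43, 8, 99]])  # assumed ids, shape (1, 5)

langs = torch.tensor([xlm_lang] * input_ids.shape[1]).view(1, -1)
print(langs)        # tensor([[0, 0, 0, 0, 0]])
print(langs.shape)  # torch.Size([1, 5]) -- one id per input position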
@@ -137,6 +143,7 @@ def main():
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--padding_text", type=str, default="")
+    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=0)
@@ -168,6 +175,17 @@ def main():
     print(args)
 
     while True:
+        xlm_lang = None
+        # XLM Language usage detailed in the issues #1414
+        if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id'):
+            if args.xlm_lang:
+                language = args.xlm_lang
+            else:
+                language = None
+                while language not in tokenizer.lang2id.keys():
+                    language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
+            xlm_lang = tokenizer.lang2id[language]
+
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory likes to have a long prompt for short inputs.
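
Usage sketch for the new flag (checkpoint shortcut assumed; any XLM checkpoint whose tokenizer exposes lang2id should behave the same). Passing --xlm_lang up front skips the interactive language prompt:

python examples/run_generation.py \
    --model_type xlm \
    --model_name_or_path xlm-mlm-enfr-1024 \
    --xlm_lang en \
    --prompt "The weather today is"

One caveat worth noting: a value supplied via --xlm_lang is used as-is; only the interactive path validates against tokenizer.lang2id, so a typo in the flag would raise a KeyError on the final lang2id lookup.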
@@ -180,11 +198,12 @@ def main():
             temperature=args.temperature,
             top_k=args.top_k,
             top_p=args.top_p,
-            device=args.device,
             is_xlnet=bool(args.model_type == "xlnet"),
+            xlm_lang=xlm_lang,
+            device=args.device,
         )
         out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
+        text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
         print(text)
         if args.prompt:
             break
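
The skip_special_tokens=True addition matters for XLM in particular, since its tokenizer wraps text in special symbols (such as </s>) that would otherwise leak into the printed generation. A minimal sketch of the difference, assuming the same checkpoint shortcut as above:

# Sketch: decode with and without skip_special_tokens.
from transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')  # assumed checkpoint
ids = tokenizer.encode("hello world", add_special_tokens=True)

print(tokenizer.decode(ids, clean_up_tokenization_spaces=True))
# special symbols around the text are printed too
print(tokenizer.decode(ids, clean_up_tokenization_spaces=True,
                       skip_special_tokens=True))
# -> hello world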