chenpangpang / transformers · Commits

Commit ef1b8b2a, authored Oct 22, 2019 by Julien Chaumond

[CTRL] warn if generation prompt does not start with a control code

see also https://github.com/salesforce/ctrl/pull/50

parent e16d4684

Showing 4 changed files with 65 additions and 3 deletions
README.md                           +1 / -1
examples/README.md                  +1 / -1
examples/run_generation.py          +4 / -1
transformers/tokenization_ctrl.py   +59 / -0
README.md

````diff
@@ -413,7 +413,7 @@ and from the Salesforce CTRL model:
 python ./examples/run_generation.py \
     --model_type=ctrl \
     --length=20 \
-    --model_name_or_path=gpt2 \
+    --model_name_or_path=ctrl \
     --temperature=0 \
     --repetition_penalty=1.2 \
 ```
````
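The corrected example runs CTRL with `--temperature=0`. As an aside on what that flag means (this is an assumption about the script's internals, not part of this diff): temperature rescales the logits before sampling, and a temperature of zero degenerates to greedy (argmax) decoding. A minimal, self-contained sketch; `softmax_with_temperature` and `next_token_greedy` are illustrative helpers, not functions from the repository:

```python
import math

def softmax_with_temperature(logits, temperature):
    # Divide logits by the temperature, then normalize. A lower temperature
    # sharpens the distribution; CTRL is tuned for low temperatures.
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def next_token_greedy(logits):
    # temperature == 0 cannot be used as a divisor; it is conventionally
    # treated as plain argmax over the logits.
    return max(range(len(logits)), key=lambda i: logits[i])

print(next_token_greedy([0.1, 2.0, 0.5]))  # -> 1
```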
examples/README.md

```diff
@@ -101,7 +101,7 @@ python run_lm_finetuning.py \
 Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py).
-Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet.
+Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
 A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
 can try out the different models available in the library.
```
examples/run_generation.py

```diff
@@ -196,7 +196,7 @@ def main():
     logger.info(args)
     if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7 :
+        if args.temperature > 0.7:
             logger.info('CTRL typically works better with lower temperatures (and lower top_k).')

     while True:
@@ -224,6 +224,9 @@ def main():
             # Models with memory likes to have a long prompt for short inputs.
             raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
         context_tokens = tokenizer.encode(raw_text)
+        if args.model_type == "ctrl":
+            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
+                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
         out = sample_sequence(
             model=model,
             context=context_tokens,
```
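The guard added here can be exercised in isolation. A minimal sketch of the same check, with a small stand-in `control_codes` mapping (the three ids below are taken from the CONTROL_CODES table in this commit) and hand-written token lists in place of `tokenizer.encode` output; `starts_with_control_code` is an illustrative helper, not part of the commit:

```python
# Three (name -> token id) pairs from the commit's CONTROL_CODES table.
CONTROL_CODES = {"Links": 63674, "Wikipedia": 37583, "Books": 6665}

def starts_with_control_code(context_tokens, control_codes):
    # Same test the commit adds: does the first encoded token id match
    # one of the known control-code ids?
    return any(context_tokens[0] == x for x in control_codes.values())

# A prompt whose first token is the "Links" control code passes the check...
print(starts_with_control_code([63674, 11, 22], CONTROL_CODES))  # -> True
# ...while an arbitrary first token would trigger the warning.
print(starts_with_control_code([42, 11, 22], CONTROL_CODES))     # -> False
```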
transformers/tokenization_ctrl.py

```diff
@@ -46,6 +46,64 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'ctrl': 256,
 }

+CONTROL_CODES = {
+    "Pregnancy": 168629,
+    "Christianity": 7675,
+    "Explain": 106423,
+    "Fitness": 63440,
+    "Saving": 63163,
+    "Ask": 27171,
+    "Ass": 95985,
+    "Joke": 163509,
+    "Questions": 45622,
+    "Thoughts": 49605,
+    "Retail": 52342,
+    "Feminism": 164338,
+    "Writing": 11992,
+    "Atheism": 192263,
+    "Netflix": 48616,
+    "Computing": 39639,
+    "Opinion": 43213,
+    "Alone": 44967,
+    "Funny": 58917,
+    "Gaming": 40358,
+    "Human": 4088,
+    "India": 1331,
+    "Joker": 77138,
+    "Diet": 36206,
+    "Legal": 11859,
+    "Norman": 4939,
+    "Tip": 72689,
+    "Weight": 52343,
+    "Movies": 46273,
+    "Running": 23425,
+    "Science": 2090,
+    "Horror": 37793,
+    "Confession": 60572,
+    "Finance": 12250,
+    "Politics": 16360,
+    "Scary": 191985,
+    "Support": 12654,
+    "Technologies": 32516,
+    "Teenage": 66160,
+    "Event": 32769,
+    "Learned": 67460,
+    "Notion": 182770,
+    "Wikipedia": 37583,
+    "Books": 6665,
+    "Extract": 76050,
+    "Confessions": 102701,
+    "Conspiracy": 75932,
+    "Links": 63674,
+    "Narcissus": 150425,
+    "Relationship": 54766,
+    "Relationships": 134796,
+    "Reviews": 41671,
+    "News": 4256,
+    "Translation": 26820,
+    "multilingual": 128406,
+}
+

 def get_pairs(word):
     """Return set of symbol pairs in a word.
@@ -68,6 +126,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     control_codes = CONTROL_CODES

     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
         super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
```
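With `control_codes` exposed on the tokenizer class, a caller can prepend a valid control code to a prompt before encoding and so avoid the new warning entirely. A hypothetical usage sketch: plain string concatenation stands in for the real `CTRLTokenizer`, `prepend_control_code` is an invented helper, and only the three (name, id) pairs come from the commit's table:

```python
# Three entries from the commit's CONTROL_CODES mapping.
CONTROL_CODES = {"Wikipedia": 37583, "Books": 6665, "Translation": 26820}

def prepend_control_code(prompt, code_name, control_codes=CONTROL_CODES):
    # CTRL expects the textual control code to be the first token of the
    # prompt; reject names that are not known control codes.
    if code_name not in control_codes:
        raise KeyError("unknown control code: %s" % code_name)
    return "%s %s" % (code_name, prompt)

print(prepend_control_code("The history of the printing press", "Books"))
# -> "Books The history of the printing press"
```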