Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4d3cf0d6
Commit
4d3cf0d6
authored
Apr 07, 2019
by
Dhanajit Brahma
Browse files
removing some redundant lines
parent
4d3721f9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
18 deletions
+19
-18
examples/run_gpt2.py
examples/run_gpt2.py
+19
-18
No files found.
examples/run_gpt2.py
View file @
4d3cf0d6
...
@@ -83,29 +83,29 @@ def run_model():
...
@@ -83,29 +83,29 @@ def run_model():
elif
args
.
length
>
model
.
config
.
n_ctx
:
elif
args
.
length
>
model
.
config
.
n_ctx
:
raise
ValueError
(
"Can't get samples longer than window size: %s"
%
model
.
config
.
n_ctx
)
raise
ValueError
(
"Can't get samples longer than window size: %s"
%
model
.
config
.
n_ctx
)
while
not
args
.
unconditional
:
if
not
args
.
unconditional
:
if
not
args
.
unconditional
:
while
True
:
raw_text
=
input
(
"Model prompt >>> "
)
raw_text
=
input
(
"Model prompt >>> "
)
while
not
raw_text
:
while
not
raw_text
:
print
(
'Prompt should not be empty!'
)
print
(
'Prompt should not be empty!'
)
raw_text
=
input
(
"Model prompt >>> "
)
raw_text
=
input
(
"Model prompt >>> "
)
context_tokens
=
enc
.
encode
(
raw_text
)
context_tokens
=
enc
.
encode
(
raw_text
)
generated
=
0
generated
=
0
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
out
=
sample_sequence
(
out
=
sample_sequence
(
model
=
model
,
length
=
args
.
length
,
model
=
model
,
length
=
args
.
length
,
context
=
context_tokens
if
not
args
.
unconditional
else
None
,
context
=
context_tokens
,
start_token
=
enc
.
encoder
[
'<|endoftext|>'
]
if
args
.
unconditional
else
None
,
start_token
=
None
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
temperature
=
args
.
temperature
,
top_k
=
args
.
top_k
,
device
=
device
temperature
=
args
.
temperature
,
top_k
=
args
.
top_k
,
device
=
device
)
)
out
=
out
[:,
len
(
context_tokens
):].
tolist
()
out
=
out
[:,
len
(
context_tokens
):].
tolist
()
for
i
in
range
(
args
.
batch_size
):
for
i
in
range
(
args
.
batch_size
):
generated
+=
1
generated
+=
1
text
=
enc
.
decode
(
out
[
i
])
text
=
enc
.
decode
(
out
[
i
])
print
(
"="
*
40
+
" SAMPLE "
+
str
(
generated
)
+
" "
+
"="
*
40
)
print
(
"="
*
40
+
" SAMPLE "
+
str
(
generated
)
+
" "
+
"="
*
40
)
print
(
text
)
print
(
text
)
print
(
"="
*
80
)
print
(
"="
*
80
)
if
args
.
unconditional
:
if
args
.
unconditional
:
generated
=
0
generated
=
0
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
...
@@ -127,3 +127,4 @@ def run_model():
...
@@ -127,3 +127,4 @@ def run_model():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
run_model
()
run_model
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment