chenpangpang / transformers · Commits

Commit fbb248a2, authored Feb 18, 2019 by thomwolf

    examples testing

Parent: 5ff0c605
Showing 3 changed files with 28 additions and 15 deletions (+28 / -15):

  examples/run_gpt2_generate_unconditional_samples.py    +11  -4
  examples/run_gpt2_interactive_conditional_samples.py   +12  -6
  pytorch_pretrained_bert/modeling_gpt2.py                +5  -5
examples/run_gpt2_generate_unconditional_samples.py  (view file @ fbb248a2)
@@ -4,7 +4,9 @@ import argparse
 import logging
 
 import torch
+import torch.nn.functional as F
 import numpy as np
+from tqdm import trange
 
 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
+    past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output
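The change to the sampling loop above replaces torch.multinomial(logits, 1) with a draw from F.softmax(logits, dim=-1): torch.multinomial expects non-negative weights, so sampling directly from raw (possibly negative) logits was incorrect. Below is a minimal, self-contained sketch of the same sampling step outside the script; the logit values and the inline top-k filtering (mirroring the scripts' top_k_logits helper) are illustrative assumptions, not part of the diff.

import torch
import torch.nn.functional as F

# Made-up logits for a batch of 2 sequences over a 5-token vocabulary.
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0, -3.0],
                       [0.1,  0.2, 0.3, 0.4,  0.5]])
temperature, top_k = 1.0, 3

logits = logits / temperature                      # temperature scaling, as in the diff

# Top-k filtering, mirroring the scripts' top_k_logits helper:
# everything below the k-th largest logit is pushed to a large negative value.
values, _ = torch.topk(logits, k=top_k)
logits = torch.where(logits < values[:, -1].unsqueeze(-1),
                     torch.full_like(logits, -1e10),
                     logits)

# New behaviour from this commit: sample from softmax probabilities,
# not from the raw logits (torch.multinomial needs non-negative weights).
log_probs = F.softmax(logits, dim=-1)              # probabilities, despite the name used in the script
prev = torch.multinomial(log_probs, num_samples=1)
print(prev.shape)                                  # torch.Size([2, 1]): one sampled token id per row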
@@ -57,6 +61,8 @@ def sample_model():
     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
 
     if args.length == -1:
         args.length = model.config.n_ctx
@@ -71,6 +77,7 @@ def sample_model():
             batch_size=args.batch_size,
             temperature=args.temperature, top_k=args.top_k, device=device
         )
+        out = out.tolist()
         for i in range(args.batch_size):
             generated += args.batch_size
             text = enc.decode(out[i])
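For context, a hedged sketch of how the pieces touched in this file fit together for unconditional sampling. It assumes the sample_sequence patched above is in scope, that 'gpt2' is a valid pretrained shortcut, and that the tokenizer exposes the '<|endoftext|>' id via enc.encoder; none of these are asserted by the diff itself.

import torch
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = GPT2Tokenizer.from_pretrained("gpt2")        # assumed shortcut name
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(device)                                   # added in this commit
model.eval()                                       # added in this commit

out = sample_sequence(                             # the function patched above
    model=model, length=40,
    start_token=enc.encoder['<|endoftext|>'],      # assumed special-token lookup
    batch_size=2,
    temperature=1.0, top_k=40, device=device,
)
out = out.tolist()                                 # added in this commit before decoding
for i in range(2):
    print(enc.decode(out[i]))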
examples/run_gpt2_interactive_conditional_samples.py  (view file @ fbb248a2)
@@ -2,8 +2,10 @@
 import argparse
 import logging
+from tqdm import trange
 
 import torch
+import torch.nn.functional as F
 import numpy as np
 
 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
+    past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output
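The interactive variant now builds the prompt as a long tensor and tiles it with .unsqueeze(0).repeat(batch_size, 1), so every sequence in the batch starts from the same context. A shape-only illustration with made-up token ids:

import torch

context_tokens = [464, 3290, 318]                  # made-up prompt token ids
batch_size = 4

context = torch.tensor(context_tokens, dtype=torch.long)     # shape [3]
context = context.unsqueeze(0).repeat(batch_size, 1)          # shape [4, 3]
print(context.shape)                               # torch.Size([4, 3]): one copy of the prompt per sample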
@@ -50,7 +54,7 @@ def interact_model():
     args = parser.parse_args()
     print(args)
 
-    if args.batch_size is None:
+    if args.batch_size == -1:
         args.batch_size = 1
     assert args.nsamples % args.batch_size == 0
@@ -61,6 +65,8 @@ def interact_model():
     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
 
     if args.length == -1:
         args.length = model.config.n_ctx // 2
@@ -81,7 +87,7 @@ def interact_model():
                 batch_size=args.batch_size,
                 temperature=args.temperature, top_k=args.top_k, device=device
             )
-            out = out[:, len(context_tokens):]
+            out = out[:, len(context_tokens):].tolist()
             for i in range(args.batch_size):
                 generated += 1
                 text = enc.decode(out[i])
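After generation, the script now strips the prompt with out[:, len(context_tokens):].tolist() before decoding, so only the continuation is printed. A hedged end-to-end sketch of the conditional path, again assuming the sample_sequence patched above is in scope and 'gpt2' as the model shortcut:

import torch
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = GPT2Tokenizer.from_pretrained("gpt2")        # assumed shortcut name
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(device)
model.eval()

raw_text = "The quick brown fox"                   # example prompt
context_tokens = enc.encode(raw_text)

out = sample_sequence(                             # the function patched above
    model=model, length=40,
    context=context_tokens,                        # tiled to the batch inside sample_sequence
    batch_size=2,
    temperature=1.0, top_k=40, device=device,
)
out = out[:, len(context_tokens):].tolist()        # drop the prompt, keep the continuation
for i in range(2):
    print(enc.decode(out[i]))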
pytorch_pretrained_bert/modeling_gpt2.py  (view file @ fbb248a2)
@@ -244,10 +244,10 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-            key = torch.cat((past_key, key), dim=-2)
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose to have same shapes
+            key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key, value))
+        present = torch.stack((key.transpose(-2, -1), value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
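This fix follows from how Attention caches keys: split_heads(key, k=True) yields keys in [batch, heads, head_dim, seq] layout while values are [batch, heads, seq, head_dim], so the cached key has to be transposed back before concatenating along dim=-1, and transposed again before being stacked into present so both cache halves share one layout. A shape-only sketch with dummy tensors (the dimensions are illustrative, not taken from the config):

import torch

batch, heads, head_dim = 1, 12, 64                 # illustrative sizes
past_seq, new_seq = 5, 1

# Layout produced by split_heads: keys come back transposed (k=True).
key   = torch.randn(batch, heads, head_dim, new_seq)
value = torch.randn(batch, heads, new_seq, head_dim)

# A cached `present` from an earlier step: both halves in [.., seq, head_dim] layout.
layer_past = torch.stack((torch.randn(batch, heads, past_seq, head_dim),
                          torch.randn(batch, heads, past_seq, head_dim)))

# As in the fixed code: bring the cached key back to the transposed layout,
# then concatenate keys along dim=-1 and values along dim=-2.
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
key   = torch.cat((past_key, key), dim=-1)         # [.., head_dim, past_seq + new_seq]
value = torch.cat((past_value, value), dim=-2)     # [.., past_seq + new_seq, head_dim]

# Transpose the key once more before stacking so key and value shapes match.
present = torch.stack((key.transpose(-2, -1), value))
print(present.shape)                               # torch.Size([2, 1, 12, 6, 64])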
@@ -278,7 +278,7 @@ class Block(nn.Module):
...
@@ -278,7 +278,7 @@ class Block(nn.Module):
self
.
mlp
=
MLP
(
4
*
nx
,
config
)
self
.
mlp
=
MLP
(
4
*
nx
,
config
)
def
forward
(
self
,
x
,
layer_past
=
None
):
def
forward
(
self
,
x
,
layer_past
=
None
):
a
,
present
=
self
.
attn
(
self
.
ln_1
(
x
),
layer_past
=
past
)
a
,
present
=
self
.
attn
(
self
.
ln_1
(
x
),
layer_past
=
layer_
past
)
x
=
x
+
a
x
=
x
+
a
m
=
self
.
mlp
(
self
.
ln_2
(
x
))
m
=
self
.
mlp
(
self
.
ln_2
(
x
))
x
=
x
+
m
x
=
x
+
m
...
@@ -531,7 +531,7 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -531,7 +531,7 @@ class GPT2Model(GPT2PreTrainedModel):
past_length
=
0
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
past
=
[
None
]
*
len
(
self
.
h
)
else
:
else
:
past
[
0
][
0
].
size
(
-
2
)
past_length
=
past
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
...
...
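With past_length now actually assigned from the cached sequence length, the default position_ids continue from where the cache left off instead of restarting at zero. A small sketch of that offset arithmetic in isolation (shapes and values are illustrative):

import torch

# Pretend 5 tokens were already generated and cached; now feed 1 new token.
past_length = 5
input_ids = torch.tensor([[50256]])                # made-up next-token id, shape [1, 1]

# Same arithmetic as the fixed GPT2Model.forward: positions continue after the cache.
position_ids = torch.arange(past_length,
                            input_ids.size(-1) + past_length,
                            dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)                                # tensor([[5]]): the new token sits at position 5, not 0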