chenpangpang / transformers · Commits

Commit 690a0dbf
authored Feb 18, 2019 by thomwolf

fix example - masking

parent fbb248a2

Showing 2 changed files with 24 additions and 24 deletions
examples/run_gpt2.py (+16 -11)
pytorch_pretrained_bert/modeling_gpt2.py (+8 -13)
examples/run_gpt2_interactive_conditional_samples.py → examples/run_gpt2.py
@@ -22,7 +22,7 @@ def top_k_logits(logits, k):
     min_values = values[:, -1]
     return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
 
-def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
+def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
         context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
@@ -38,11 +38,14 @@ def sample_sequence(model, length, start_token=None, batch_size=None, context=No
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
             log_probs = F.softmax(logits, dim=-1)
-            prev = torch.multinomial(log_probs, num_samples=1)
+            if sample:
+                prev = torch.multinomial(log_probs, num_samples=1)
+            else:
+                _, prev = torch.topk(log_probs, k=1, dim=-1)
             output = torch.cat((output, prev), dim=1)
     return output
 
-def interact_model():
+def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
     parser.add_argument("--seed", type=int, default=0)
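The new sample argument switches the decoding rule: multinomial sampling from the softmax distribution when sample=True, greedy top-1 selection when sample=False. A minimal standalone sketch of the two branches on a toy logits tensor (not using the repository code; values chosen for illustration):

import torch
import torch.nn.functional as F

# Toy next-token logits: batch of 2, vocabulary of 5.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0],
                       [0.1, 3.0, 0.2, 0.0, -2.0]])
probs = F.softmax(logits, dim=-1)

# sample=True: draw one token id per row from the distribution (stochastic).
sampled = torch.multinomial(probs, num_samples=1)

# sample=False: take the most likely token id per row (deterministic).
_, greedy = torch.topk(probs, k=1, dim=-1)

print(sampled.shape, greedy.shape)  # both torch.Size([2, 1])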
@@ -51,6 +54,7 @@ def interact_model():
     parser.add_argument("--length", type=int, default=-1)
     parser.add_argument("--temperature", type=int, default=1)
     parser.add_argument("--top_k", type=int, default=0)
+    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
     args = parser.parse_args()
     print(args)
 
@@ -73,17 +77,19 @@ def interact_model():
     elif args.length > model.config.n_ctx:
         raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
 
-    while True:
-        raw_text = input("Model prompt >>> ")
-        while not raw_text:
-            print('Prompt should not be empty!')
-            raw_text = input("Model prompt >>> ")
-        context_tokens = enc.encode(raw_text)
+    while not args.unconditional:
+        if not args.unconditional:
+            raw_text = input("Model prompt >>> ")
+            while not raw_text:
+                print('Prompt should not be empty!')
+                raw_text = input("Model prompt >>> ")
+            context_tokens = enc.encode(raw_text)
         generated = 0
         for _ in range(args.nsamples // args.batch_size):
             out = sample_sequence(
                 model=model, length=args.length,
-                context=context_tokens,
+                context=context_tokens if not args.unconditional else None,
+                start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
                 batch_size=args.batch_size,
                 temperature=args.temperature, top_k=args.top_k, device=device
             )
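With the new --unconditional flag the script either continues an encoded prompt or starts generation from the <|endoftext|> token alone. A rough sketch of the two call patterns into sample_sequence, assuming model, enc (the GPT-2 tokenizer) and device are set up earlier in the script as usual; the prompt string and lengths are illustrative only:

# Conditional generation: continue an encoded prompt.
context_tokens = enc.encode("The quick brown fox")
out = sample_sequence(
    model=model, length=40,
    context=context_tokens,
    start_token=None,
    batch_size=1,
    temperature=1, top_k=0, device=device
)

# Unconditional generation: no context, seed with the <|endoftext|> token id.
out = sample_sequence(
    model=model, length=40,
    context=None,
    start_token=enc.encoder['<|endoftext|>'],
    batch_size=1,
    temperature=1, top_k=0, device=device
)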
@@ -96,5 +102,4 @@ def interact_model():
             print("=" * 80)
 
 if __name__ == '__main__':
-    interact_model()
-
+    run_model()
pytorch_pretrained_bert/modeling_gpt2.py
@@ -87,10 +87,6 @@ def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
             if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
-        if m_name[-11:] == '_embeddings':
-            pointer = getattr(pointer, 'weight')
-        elif m_name == 'kernel':
-            array = np.transpose(array)
         try:
             assert pointer.shape == array.shape
         except AssertionError as e:
@@ -216,10 +212,9 @@ class Attention(nn.Module):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
-        # XD: self.b may be larger than w, so we need to crop it
-        b = self.bias[:, :, :w.size(-2), :w.size(-1)]
-        w = w * b + -1e10 * (1 - b)
+        nd, ns = w.size(-2), w.size(-1)
+        b = self.bias[:, :, ns-nd:ns, :ns]
+        w = w * b - 1e10 * (1 - b)
         w = nn.Softmax(dim=-1)(w)
         return torch.matmul(w, v)
 
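This masking change is the substance of the commit. The registered causal mask is a lower-triangular buffer of shape (1, 1, n_ctx, n_ctx), while the attention scores w have shape (batch, head, nd, ns) with ns >= nd once cached keys from layer_past are concatenated in. Cropping the mask with [:, :, :nd, :ns] takes the top-left block, so row i of the new queries could only see columns up to i and most of the cached positions were masked out; cropping with [:, :, ns-nd:ns, :ns] takes the rows that actually correspond to the new query positions, so each new token attends to the whole cache plus the already-generated part of the new segment. A small standalone sketch of the difference (sizes are illustrative, and the buffer mirrors the one Attention registers):

import torch

n_ctx = 8                              # illustrative maximum context length
past_len, new_len = 5, 3               # 5 cached tokens, 3 new tokens
nd, ns = new_len, past_len + new_len   # w would have shape (batch, head, nd, ns)

# Lower-triangular causal mask, shaped like the Attention buffer.
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

old_crop = bias[:, :, :nd, :ns]          # before the fix: top-left block
new_crop = bias[:, :, ns - nd:ns, :ns]   # after the fix: rows aligned with the new queries

print(old_crop[0, 0])  # row i sees only columns <= i, hiding most of the cache
print(new_crop[0, 0])  # row i sees all cached columns plus positions up to itself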
@@ -233,9 +228,9 @@ class Attention(nn.Module):
         new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
         x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
         if k:
-            return x.permute(0, 2, 3, 1)
+            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
         else:
-            return x.permute(0, 2, 1, 3)
+            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def forward(self, x, layer_past=None):
         x = self.c_attn(x)
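The added comments document the two layouts split_heads produces: with k=True, keys come out as (batch, head, head_features, seq_length) so they can be matmul'd directly against queries, while queries and values come out as (batch, head, seq_length, head_features). A standalone shape check with made-up dimensions:

import torch

batch, seq, n_head, head_feat = 2, 7, 4, 16
x = torch.randn(batch, seq, n_head * head_feat)

# Same reshape as split_heads: break the hidden size into (n_head, head_features).
new_x_shape = x.size()[:-1] + (n_head, head_feat)
x = x.view(*new_x_shape)               # (batch, seq, head, head_features)

key_layout = x.permute(0, 2, 3, 1)     # (batch, head, head_features, seq_length)
query_layout = x.permute(0, 2, 1, 3)   # (batch, head, seq_length, head_features)

print(key_layout.shape)    # torch.Size([2, 4, 16, 7])
print(query_layout.shape)  # torch.Size([2, 4, 7, 16])

# With these layouts, query @ key directly yields per-head attention scores.
w = torch.matmul(query_layout, key_layout)
print(w.shape)             # torch.Size([2, 4, 7, 7])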
@@ -244,10 +239,10 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose to have same shapes
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
             key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key.transpose(-2, -1), value))
+        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
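The reworded comments spell out the bookkeeping around layer_past: keys are used in transposed form, (batch, head, head_features, seq_length), so the cached key is transposed back on the way in, concatenated along dim=-1 (its sequence axis), and transposed once more when stacking present so both cached tensors share the (batch, head, seq_length, head_features) layout. A standalone shape walk-through with small made-up dimensions:

import torch

batch, head, head_feat = 2, 4, 16
past_len, new_len = 5, 3

# layer_past as a previous forward would have stacked it:
# key and value both stored as (batch, head, past_len, head_features).
layer_past = torch.stack((
    torch.randn(batch, head, past_len, head_feat),   # key, stored un-transposed
    torch.randn(batch, head, past_len, head_feat),   # value
))

past_key = layer_past[0].transpose(-2, -1)   # back to (batch, head, head_features, past_len)
past_value = layer_past[1]

# New projections after split_heads: key transposed, value not.
key = torch.randn(batch, head, head_feat, new_len)
value = torch.randn(batch, head, new_len, head_feat)

key = torch.cat((past_key, key), dim=-1)        # (2, 4, 16, 8): cache + new along the sequence axis
value = torch.cat((past_value, value), dim=-2)  # (2, 4, 8, 16)

present = torch.stack((key.transpose(-2, -1), value))  # (2, 2, 4, 8, 16), both in the same layout
print(key.shape, value.shape, present.shape)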
@@ -522,7 +517,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         block = Block(config.n_ctx, config, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
-        self.ln_f = LayerNorm(config.n_embd)
+        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
         self.apply(self.init_weights)
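The final layer norm now takes its epsilon from the configuration instead of relying on the layer's default, so every normalization layer in the model is built from the same config.layer_norm_epsilon. A minimal sketch of the idea using torch.nn.LayerNorm as a stand-in for the module's own LayerNorm class (values illustrative only):

import torch
import torch.nn as nn

n_embd = 8
layer_norm_epsilon = 1e-5   # stand-in for config.layer_norm_epsilon

ln_default = nn.LayerNorm(n_embd)                         # before: whatever eps the class defaults to
ln_config = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)  # after: eps pinned by the config

x = torch.randn(2, 3, n_embd)
# The outputs only agree when the default eps happens to match the configured one.
print(torch.allclose(ln_default(x), ln_config(x)))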