chenpangpang / transformers / Commits

Commit ab90d4cd, authored Jan 09, 2019 by thomwolf (parent dc5df92f)

    adding docs and example for OpenAI GPT

Showing 3 changed files with 511 additions and 44 deletions:

    examples/run_openai_gpt.py                   +304   -0
    pytorch_pretrained_bert/modeling.py          +2     -2
    pytorch_pretrained_bert/modeling_openai.py   +205   -42

examples/run_openai_gpt.py  (new file, 0 → 100644) @ ab90d4cd
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
" Run OpenAI GPT on RocStories"
import argparse
import os
import random
import logging

from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

# from analysis import rocstories as rocstories_analysis
# from datasets import rocstories
# from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
# from opt import OpenAIAdam
# from text_utils import TextEncoder
# from utils import (encode_dataset, iter_data,
#                    ResultLogger, make_path)
# from loss import MultipleChoiceLossCompute

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTDoubleHeadsModel
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def transform_roc(X1, X2, X3):
    n_batch = len(X1)
    xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)
    start = encoder['_start_']
    delimiter = encoder['_delimiter_']
    for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)):
        x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
        x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
        l12 = len(x12)
        l13 = len(x13)
        xmb[i, 0, :l12, 0] = x12
        xmb[i, 1, :l13, 0] = x13
        mmb[i, 0, :l12] = 1
        mmb[i, 1, :l13] = 1
    # Position information that is added to the input embeddings in the TransformerModel
    xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    return xmb, mmb


def iter_apply(Xs, Ms, Ys):
    # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
    logits = []
    cost = 0
    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            _, clf_logits = dh_model(XMB)
            clf_logits *= n
            clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
            clf_losses *= n
            logits.append(clf_logits.to("cpu").numpy())
            cost += clf_losses.sum().item()
        logits = np.concatenate(logits, 0)
    return logits, cost


def iter_predict(Xs, Ms):
    logits = []
    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            _, clf_logits = dh_model(XMB)
            logits.append(clf_logits.to("cpu").numpy())
    logits = np.concatenate(logits, 0)
    return logits


def log(save_dir, desc):
    global best_score
    print("Logging")
    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
    tr_cost = tr_cost / len(trY[:n_valid])
    va_cost = va_cost / n_valid
    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
    logger.log(n_epochs=n_epochs, n_updates=n_updates,
               tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
    if submit:
        score = va_acc
        if score > best_score:
            best_score = score
            path = os.path.join(save_dir, desc, 'best_params')
            torch.save(dh_model.state_dict(), make_path(path))


def predict(dataset, submission_dir):
    filename = filenames[dataset]
    pred_fn = pred_fns[dataset]
    label_decoder = label_decoders[dataset]
    predictions = pred_fn(iter_predict(teX, teM))
    if label_decoder is not None:
        predictions = [label_decoder[prediction] for prediction in predictions]
    path = os.path.join(submission_dir, filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write('{}\t{}\n'.format('index', 'prediction'))
        for i, prediction in enumerate(predictions):
            f.write('{}\t{}\n'.format(i, prediction))


def run_epoch():
    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
                                   n_batch=n_batch_train, truncate=True, verbose=True):
        global n_updates
        dh_model.train()
        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB = torch.tensor(mmb).to(device)
        lm_logits, clf_logits = dh_model(XMB)
        compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
        n_updates += 1
        if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
            log(save_dir, desc)


argmax = lambda x: np.argmax(x, 1)

pred_fns = {
    'rocstories': argmax,
}

filenames = {
    'rocstories': 'ROCStories.tsv',
}

label_decoders = {
    'rocstories': None,
}

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)
    args = parser.parse_args()
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    print("Encoding dataset...")
    ((trX1, trX2, trX3, trY),
     (vaX1, vaX2, vaX3, vaY),
     (teX1, teX2, teX3)) = encode_dataset(*rocstories(data_dir, n_valid=args.n_valid),
                                          encoder=text_encoder)
    encoder['_start_'] = len(encoder)
    encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 3
    max_len = n_ctx // 2 - 2
    n_ctx = min(max(
        [len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
         for x1, x2, x3 in zip(trX1, trX2, trX3)]
        + [len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
           for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
        + [len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
           for x1, x2, x3 in zip(teX1, teX2, teX3)]
    ) + 3, n_ctx)
    vocab = n_vocab + n_special + n_ctx
    trX, trM = transform_roc(trX1, trX2, trX3)
    vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
    if submit:
        teX, teM = transform_roc(teX1, teX2, teX3)

    n_train = len(trY)
    n_valid = len(vaY)
    n_batch_train = args.n_batch * max(n_gpu, 1)
    n_updates_total = (n_train // n_batch_train) * args.n_iter

    dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)

    criterion = nn.CrossEntropyLoss(reduce=False)
    model_opt = OpenAIAdam(dh_model.parameters(),
                           lr=args.lr,
                           schedule=args.lr_schedule,
                           warmup=args.lr_warmup,
                           t_total=n_updates_total,
                           b1=args.b1,
                           b2=args.b2,
                           e=args.e,
                           l2=args.l2,
                           vector_l2=args.vector_l2,
                           max_grad_norm=args.max_grad_norm)
    compute_loss_fct = MultipleChoiceLossCompute(criterion,
                                                 criterion,
                                                 args.lm_coef,
                                                 model_opt)
    load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)

    dh_model.to(device)
    dh_model = nn.DataParallel(dh_model)

    n_updates = 0
    n_epochs = 0
    if dataset != 'stsb':
        trYt = trY
    if submit:
        path = os.path.join(save_dir, desc, 'best_params')
        torch.save(dh_model.state_dict(), make_path(path))
    best_score = 0
    for i in range(args.n_iter):
        print("running epoch", i)
        run_epoch()
        n_epochs += 1
        log(save_dir, desc)
    if submit:
        path = os.path.join(save_dir, desc, 'best_params')
        dh_model.load_state_dict(torch.load(path))
        predict(dataset, args.submission_dir)
        if args.analysis:
            rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
                                os.path.join(log_dir, 'rocstories.jsonl'))
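A note on the input packing in the example above: `transform_roc` builds, for every story, two candidate sequences `[start] + context + [delimiter] + ending + [clf_token]` and stores them in a `[n_batch, 2, n_ctx, 2]` tensor whose last channel carries the position ids, which live in the same embedding matrix as the words, right after the `n_vocab + n_special` token ids. The sketch below replays that packing with toy sizes; the token ids and dimensions are invented for illustration and are not taken from the script's BPE encoder. It is also the easiest way to see the layout without running the full example, since at this commit the script still calls helpers whose imports are commented out at the top (`iter_data`, `ResultLogger`, `make_path`, `encode_dataset`, `rocstories`, `DoubleHeadModel`, `MultipleChoiceLossCompute`).

```python
import numpy as np

# Toy illustration of the packing performed by transform_roc above.
# The vocabulary sizes and token ids below are made up for illustration only.
n_vocab, n_special, n_ctx = 10, 3, 12           # hypothetical sizes
start, delimiter, clf_token = 10, 11, 12        # the 3 special ids appended after the word vocab
x1, x2, x3 = [4, 5, 6], [7, 8], [9, 3]          # context, true ending, distractor (already BPE ids)

xmb = np.zeros((1, 2, n_ctx, 2), dtype=np.int32)   # [batch, 2 choices, n_ctx, (token id, position id)]
mmb = np.zeros((1, 2, n_ctx), dtype=np.float32)    # 1.0 over real tokens, 0.0 over padding

for j, ending in enumerate((x2, x3)):
    seq = [start] + x1 + [delimiter] + ending + [clf_token]
    xmb[0, j, :len(seq), 0] = seq
    mmb[0, j, :len(seq)] = 1

# Position ids live in the same embedding matrix, right after the word and special tokens.
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)

print(xmb[0, 0, :, 0])   # token ids of choice 0 (zero-padded)
print(xmb[0, 0, :, 1])   # position ids 13..24
print(mmb[0])            # per-choice mask used for the losses
```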
pytorch_pretrained_bert/modeling.py  (+2 -2) @ ab90d4cd

...
@@ -659,10 +659,10 @@ class BertForPreTraining(BertPreTrainedModel):
             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
             input sequence length in the current batch. It's the mask that we typically use for attention when
             a batch has varying length sentences.
-        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
+        `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
             with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
             is only computed for the labels set in [0, ..., vocab_size]
-        `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
+        `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
             with indices selected in [0, 1].
             0 => next sentence is the continuation, 1 => next sentence is a random sentence.
...
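The two docstring tweaks above simply mark `masked_lm_labels` and `next_sentence_label` as optional. Under the usual behaviour described elsewhere in that docstring (pass both labels and the forward returns the combined pre-training loss, omit them and it returns the masked-LM and next-sentence logits), a minimal sketch looks like the following; the tiny config values are arbitrary, chosen only so the example runs quickly on random weights.

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining

# Tiny, randomly initialized config so the sketch runs without a pretrained checkpoint.
config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertForPreTraining(config)
model.eval()

input_ids = torch.randint(0, 100, (2, 8))       # [batch_size, sequence_length]
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

# Without the optional labels: returns (masked-LM logits, next-sentence logits).
prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, attention_mask)

# With the optional labels: returns the summed masked-LM + next-sentence loss.
masked_lm_labels = input_ids.clone()
masked_lm_labels[:, 4:] = -1                    # -1 positions are ignored by the loss
next_sentence_label = torch.tensor([0, 1])
loss = model(input_ids, token_type_ids, attention_mask,
             masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_label)
```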
pytorch_pretrained_bert/modeling_openai.py  (+205 -42) @ ab90d4cd

...
@@ -149,19 +149,19 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, cfg, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % cfg.n_head == 0
+        assert n_state % config.n_head == 0
         self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
-        self.n_head = cfg.n_head
+        self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
-        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
-        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)

     def _attn(self, q, k, v):
         w = torch.matmul(q, k)
...
@@ -203,13 +203,13 @@ class Attention(nn.Module):
 class MLP(nn.Module):
-    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
+    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
         super(MLP, self).__init__()
-        nx = cfg.n_embd
+        nx = config.n_embd
         self.c_fc = Conv1D(n_state, 1, nx)
         self.c_proj = Conv1D(nx, 1, n_state)
-        self.act = ACT_FNS[cfg.afn]
-        self.dropout = nn.Dropout(cfg.resid_pdrop)
+        self.act = ACT_FNS[config.afn]
+        self.dropout = nn.Dropout(config.resid_pdrop)

     def forward(self, x):
         h = self.act(self.c_fc(x))
...
@@ -218,12 +218,12 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, cfg, scale=False):
+    def __init__(self, n_ctx, config, scale=False):
         super(Block, self).__init__()
-        nx = cfg.n_embd
-        self.attn = Attention(nx, n_ctx, cfg, scale)
+        nx = config.n_embd
+        self.attn = Attention(nx, n_ctx, config, scale)
         self.ln_1 = LayerNorm(nx)
-        self.mlp = MLP(4 * nx, cfg)
+        self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx)

     def forward(self, x):
...
@@ -237,9 +237,9 @@ class Block(nn.Module):
 class OpenAIGPTLMHead(nn.Module):
     """ Language Model Head for the transformer """

-    def __init__(self, model_embeddings_weights, cfg):
+    def __init__(self, model_embeddings_weights, config):
         super(OpenAIGPTLMHead, self).__init__()
-        self.n_embd = cfg.n_embd
+        self.n_embd = config.n_embd
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights):
...
@@ -257,12 +257,12 @@ class OpenAIGPTLMHead(nn.Module):
 class OpenAIGPTMultipleChoiceHead(nn.Module):
     """ Classifier Head for the transformer """

-    def __init__(self, cfg):
+    def __init__(self, config):
         super(OpenAIGPTMultipleChoiceHead, self).__init__()
-        self.n_embd = cfg.n_embd
+        self.n_embd = config.n_embd
         # self.multiple_choice_token = multiple_choice_token
-        self.dropout = nn.Dropout2d(cfg.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(cfg.n_embd, 1)
+        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
+        self.linear = nn.Linear(config.n_embd, 1)

         nn.init.normal_(self.linear.weight, std=0.02)
         nn.init.normal_(self.linear.bias, 0)
...
@@ -428,15 +428,63 @@ class OpenAIGPTPreTrainedModel(nn.Module):
 class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model """
+    """OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
+
+    The main implementation difference between BERT and the OpenAI is the use, in OpenAI GPT, of a single embedding matrix
+    to store the word, special ([SEP], [CLS]...) and position embeddings.
+    The embeddings are ordered as follow in the word embeddings matrice:
+        [0,                                            ----------------------
+         ...                                           -> word embeddings
+         config.vocab_size - 1,                        ______________________
+         config.vocab_size,
+         ...                                           -> special embeddings
+         config.vocab_size + config.n_special - 1,     ______________________
+         config.vocab_size + config.n_special,
+         ...                                           -> position embeddings
+         total_num_embeddings - 1]                     ______________________
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use the associate indices to index the embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: a OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
+            were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+
+    Outputs:
+        `hidden_states`: the encoded-hidden-states at the top of the model
+            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
+            (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()
+
+    model = modeling_openai.OpenAIGPTModel(config)
+    hidden_states = model(input_ids)
+    ```
+    """

-    def __init__(self, cfg):
-        super(OpenAIGPTModel, self).__init__(cfg)
-        total_embeddings_size = cfg.vocab_size + cfg.n_special + cfg.n_ctx
-        self.embed = nn.Embedding(total_embeddings_size, cfg.n_embd)
-        self.drop = nn.Dropout(cfg.embd_pdrop)
-        block = Block(cfg.n_ctx, cfg, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
+    def __init__(self, config):
+        super(OpenAIGPTModel, self).__init__(config)
+        total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
+        self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
+        self.drop = nn.Dropout(config.embd_pdrop)
+        block = Block(config.n_ctx, config, scale=True)
+        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.apply(self.init_weights)
         # nn.init.normal_(self.embed.weight, std=0.02)
...
@@ -480,11 +528,67 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         return hidden_states.view(*input_shape, hidden_states.size(-1))


 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model with language model and classification heads """
+    """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
+
+    There are two main implementation differences between BERT and the OpenAI GPT:
+        - the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token
+            vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right)
+        - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
+    The embeddings are ordered as follow in the word embeddings matrice:
+        [0,                                            ----------------------
+         ...                                           -> word embeddings
+         config.vocab_size - 1,                        ______________________
+         config.vocab_size,
+         ...                                           -> special embeddings
+         config.vocab_size + config.n_special - 1,     ______________________
+         config.vocab_size + config.n_special,
+         ...                                           -> position embeddings
+         total_num_embeddings - 1]                     ______________________
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use these indices to index the word, special and position embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: a OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
+            were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
+            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
+            is only computed for the labels set in [0, ..., vocab_size]
+
+    Outputs:
+        if `lm_labels` is not `None`:
+            Outputs the language modeling loss.
+        else:
+            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_num_embeddings]
+                (or more generally [d_1, ..., d_n, total_num_embeddings] were d_1 ... d_n are the dimension of input_ids)
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()
+
+    model = modeling_openai.OpenAIGPTLMHeadModel(config)
+    lm_logits = model(input_ids)
+    ```
+    """

-    def __init__(self, cfg):
-        super(OpenAIGPTLMHeadModel, self).__init__(cfg)
-        self.transformer = OpenAIGPTModel(cfg)
-        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
+    def __init__(self, config):
+        super(OpenAIGPTLMHeadModel, self).__init__(config)
+        self.transformer = OpenAIGPTModel(config)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
         self.apply(self.init_weights)

     def set_num_special_tokens(self, num_special_tokens):
...
@@ -502,12 +606,74 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         return lm_logits


 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model with language model and classification heads """
+    """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
+
+    There are two main implementation differences between BERT and the OpenAI GPT:
+        - the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token
+            vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right)
+        - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
+    The embeddings are ordered as follow in the word embeddings matrice:
+        [0,                                            ----------------------
+         ...                                           -> word embeddings
+         config.vocab_size - 1,                        ______________________
+         config.vocab_size,
+         ...                                           -> special embeddings
+         config.vocab_size + config.n_special - 1,     ______________________
+         config.vocab_size + config.n_special,
+         ...                                           -> position embeddings
+         total_num_embeddings - 1]                     ______________________
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use these indices to index the word, special and position embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: a OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with the word BPE token indices selected in the range [0, config.vocab_size[
+        `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special,
+            config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with indices selected in [-1, 0, ..., total_num_embeddings]. All labels set to -1 are ignored (masked), the loss
+            is only computed for the labels set in [0, ..., total_num_embeddings]
+        `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
+            with indices selected in [0, ..., num_choices].
+
+    Outputs:
+        if `lm_labels` and `multiple_choice_labels` are not `None`:
+            Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
+        else: a tuple with
+            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_num_embeddings]
+            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()

+    model = modeling_openai.OpenAIGPTLMHeadModel(config)
+    lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask)
+    ```
+    """

-    def __init__(self, cfg):
-        super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
-        self.transformer = OpenAIGPTModel(cfg)
-        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
-        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(cfg)
+    def __init__(self, config):
+        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
+        self.transformer = OpenAIGPTModel(config)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
+        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)

     def set_num_special_tokens(self, num_special_tokens):
...
@@ -517,9 +683,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
                 lm_labels=None, multiple_choice_labels=None):
-        """ input_ids should be of shape B x C x S
-            lm_labels can be masked using the -1 value
-        """
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
         multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
...
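The new docstrings above all describe the same single embedding matrix that stacks word, special and position embeddings. As a quick companion to that description, the sketch below just spells out the index arithmetic. The vocabulary size used here is the commonly cited 40,478-token GPT BPE vocabulary and the other sizes match the defaults appearing in this commit, but treat all of the numbers as illustrative values rather than something read from a real `OpenAIGPTConfig`.

```python
# Index layout of the single OpenAI GPT embedding matrix described in the docstrings above.
# Sizes are illustrative: ~40478 BPE tokens, 3 special tokens, 512 positions.
vocab_size, n_special, n_ctx = 40478, 3, 512

word_ids = range(0, vocab_size)                           # 0 .. vocab_size - 1
special_ids = range(vocab_size, vocab_size + n_special)   # e.g. _start_, _delimiter_, _classify_ (trained during fine-tuning)
position_ids = range(vocab_size + n_special,
                     vocab_size + n_special + n_ctx)      # what run_openai_gpt.py writes into xmb[..., 1]

total_num_embeddings = vocab_size + n_special + n_ctx
assert len(word_ids) + len(special_ids) + len(position_ids) == total_num_embeddings
print(total_num_embeddings)                               # 40993 with these example sizes
```

The `run_openai_gpt.py` example added in this commit follows exactly this layout: it appends its three special tokens to the encoder and fills the position channel with `np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)`.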