Commit 1320e4ec authored Feb 09, 2019 by thomwolf

mc_token_mask => mc_token_ids

Parent: f4a07a39
Showing 3 changed files with 35 additions and 40 deletions (+35 -40):

  examples/run_openai_gpt.py                   +9  -9
  pytorch_pretrained_bert/modeling_openai.py   +18 -23
  tests/modeling_openai_test.py                +8  -8
examples/run_openai_gpt.py (view file @ 1320e4ec)
@@ -64,7 +64,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
     for dataset in encoded_datasets:
         n_batch = len(dataset)
         input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
-        mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
+        mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
         lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
         mc_labels = np.zeros((n_batch,), dtype=np.int64)
         for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
@@ -72,12 +72,12 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
             with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
             input_ids[i, 0, :len(with_cont1)] = with_cont1
             input_ids[i, 1, :len(with_cont2)] = with_cont2
-            mc_token_mask[i, 0, len(with_cont1) - 1] = 1
-            mc_token_mask[i, 1, len(with_cont2) - 1] = 1
+            mc_token_ids[i, 0] = len(with_cont1) - 1
+            mc_token_ids[i, 1] = len(with_cont2) - 1
             lm_labels[i, 0, :len(with_cont1) - 1] = with_cont1[1:]
             lm_labels[i, 1, :len(with_cont2) - 1] = with_cont2[1:]
             mc_labels[i] = mc_label
-        all_inputs = (input_ids, mc_token_mask, lm_labels, mc_labels)
+        all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
         tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
     return tensor_datasets
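
The hunk above swaps the per-position mask for a per-choice index: instead of an (n_batch, 2, input_len) array with a 1 at the classification token, mc_token_ids stores that token's position directly in an (n_batch, 2) array. A minimal standalone sketch (not part of the commit, with made-up token ids) contrasting the two representations:

```python
# Sketch only: hypothetical token ids, input_len = 5, one example with two choices.
import numpy as np

input_len = 5
with_cont1 = [5, 17, 23, 8, 2]   # hypothetical, length 5
with_cont2 = [5, 17, 23, 2]      # hypothetical, length 4

# Old representation: a (2, input_len) mask with a 1 at the classification token.
mc_token_mask = np.zeros((2, input_len), dtype=np.int64)
mc_token_mask[0, len(with_cont1) - 1] = 1
mc_token_mask[1, len(with_cont2) - 1] = 1

# New representation: a (2,) vector holding the classification token position itself.
mc_token_ids = np.zeros((2,), dtype=np.int64)
mc_token_ids[0] = len(with_cont1) - 1   # 4
mc_token_ids[1] = len(with_cont2) - 1   # 3

# Both encode the same positions; the index form is what the new head consumes.
assert list(np.argmax(mc_token_mask, axis=1)) == list(mc_token_ids)
```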
@@ -197,8 +197,8 @@ def main():
             tqdm_bar = tqdm(train_dataloader, desc="Training")
             for step, batch in enumerate(tqdm_bar):
                 batch = tuple(t.to(device) for t in batch)
-                input_ids, mc_token_mask, lm_labels, mc_labels = batch
-                losses = model(input_ids, mc_token_mask, lm_labels, mc_labels)
+                input_ids, mc_token_ids, lm_labels, mc_labels = batch
+                losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
                 loss = args.lm_coef * losses[0] + losses[1]
                 loss.backward()
                 optimizer.step()
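
For context, `losses` here is the pair returned by the double-heads model when both lm_labels and mc_labels are supplied: the language-modeling loss and the multiple-choice loss. The script scales the LM term by args.lm_coef before adding the classification term. A tiny sketch with hypothetical numbers:

```python
# Sketch only: made-up per-batch loss values and a placeholder lm_coef.
lm_loss, mc_loss = 3.2, 0.9          # losses[0], losses[1] (hypothetical)
lm_coef = 0.5                        # placeholder for args.lm_coef
loss = lm_coef * lm_loss + mc_loss   # the scalar that loss.backward() differentiates
assert abs(loss - 2.5) < 1e-9
```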
@@ -226,10 +226,10 @@ def main():
         nb_eval_steps, nb_eval_examples = 0, 0
         for batch in tqdm(eval_dataloader, desc="Evaluating"):
             batch = tuple(t.to(device) for t in batch)
-            input_ids, mc_token_mask, lm_labels, mc_labels = batch
+            input_ids, mc_token_ids, lm_labels, mc_labels = batch
             with torch.no_grad():
-                _, mc_loss = model(input_ids, mc_token_mask, lm_labels, mc_labels)
-                _, mc_logits = model(input_ids, mc_token_mask)
+                _, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels)
+                _, mc_logits = model(input_ids, mc_token_ids)
             mc_logits = mc_logits.detach().cpu().numpy()
             mc_labels = mc_labels.to('cpu').numpy()
pytorch_pretrained_bert/modeling_openai.py (view file @ 1320e4ec)
@@ -366,23 +366,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
         nn.init.normal_(self.linear.weight, std=0.02)
         nn.init.normal_(self.linear.bias, 0)

-    def forward(self, hidden_states, mc_token_mask):
+    def forward(self, hidden_states, mc_token_ids):
         # Classification logits
-        # hidden_states = hidden_states.view(-1, self.n_embd)
-        # mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
-        mc_token_mask = mc_token_mask.float()
-        multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1)
-        multiple_choice_h = multiple_choice_h.sum(dim=-2)
-        # flat = x[..., 0].contiguous().view(-1)
-        # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
-        # multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1)
-        # # This double transposition is there to replicate the behavior
-        # # of the noise_shape argument in the tensorflow
-        # # implementation. For more details, see
-        # # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
-        # multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
-        # multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd)
+        # hidden_state (bsz, num_choices, seq_length, hidden_size)
+        # mc_token_ids (bsz, num_choices)
+        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
+        # (bsz, num_choices, 1, hidden_size)
+        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
+        # (bsz, num_choices, hidden_size)
         multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
+        # (bsz, num_choices)
         return multiple_choice_logits
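
The rewritten forward selects, for each choice, the hidden state at the position given by mc_token_ids instead of masking and summing over the sequence. A minimal standalone sketch (hypothetical shapes, not from the commit) checking that the gather-based selection picks the same hidden states as the old mask-and-sum would for the same positions:

```python
# Sketch only: random hidden states with made-up dimensions.
import torch

bsz, num_choices, seq_length, hidden_size = 2, 2, 5, 4
hidden_states = torch.randn(bsz, num_choices, seq_length, hidden_size)
mc_token_ids = torch.tensor([[4, 3], [2, 4]])              # (bsz, num_choices)

# New path: expand the indices and gather along the sequence dimension.
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_size)
h_new = hidden_states.gather(2, index).squeeze(2)          # (bsz, num_choices, hidden_size)

# Old path: build the equivalent one-hot mask and sum over the sequence dimension.
mc_token_mask = torch.zeros(bsz, num_choices, seq_length)
mc_token_mask.scatter_(2, mc_token_ids.unsqueeze(-1), 1.0)
h_old = (hidden_states * mc_token_mask.unsqueeze(-1)).sum(dim=-2)

assert torch.allclose(h_new, h_old)
```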
@@ -727,7 +720,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
-    """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
+    """OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").

     OpenAI GPT use a single embedding matrix to store the word and special embeddings.
     Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
@@ -750,8 +743,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     config: a OpenAIGPTConfig class instance with the configuration to build a new model

     Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
-            were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
+        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
+            indices selected in the range [0, total_tokens_embeddings[
+        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
+            which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
         `position_ids`: an optional torch.LongTensor with the same shape as input_ids
             with the position indices (selected in the range [0, config.n_positions - 1[.
         `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
@@ -775,13 +770,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     Example usage:
     ```python
     # Already been converted into BPE token ids
-    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-    mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choice, seq length)
+    mc_token_ids = torch.LongTensor([[2], [1]])  # (bsz, number of choice)

     config = modeling_openai.OpenAIGPTConfig()

     model = modeling_openai.OpenAIGPTLMHeadModel(config)
-    lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
+    lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
     ```
     """
@@ -799,10 +794,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)

-    def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
-        mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
tests/modeling_openai_test.py (view file @ 1320e4ec)
@@ -89,11 +89,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
             mc_labels = None
             lm_labels = None
-            mc_token_mask = None
+            mc_token_ids = None
             if self.use_labels:
                 mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
-                mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
+                mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()

             config = OpenAIGPTConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -109,10 +109,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 initializer_range=self.initializer_range)

             return (config, input_ids, token_type_ids, position_ids,
-                    mc_labels, lm_labels, mc_token_mask)
+                    mc_labels, lm_labels, mc_token_ids)

         def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
-                                mc_labels, lm_labels, mc_token_mask):
+                                mc_labels, lm_labels, mc_token_ids):
             model = OpenAIGPTModel(config)
             model.eval()
             hidden_states = model(input_ids, position_ids, token_type_ids)
@@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
         def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
-                                  mc_labels, lm_labels, mc_token_mask):
+                                  mc_labels, lm_labels, mc_token_ids):
             model = OpenAIGPTLMHeadModel(config)
             model.eval()
             loss = model(input_ids, position_ids, token_type_ids, lm_labels)
@@ -151,13 +151,13 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [])

         def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
-                                       mc_labels, lm_labels, mc_token_mask):
+                                       mc_labels, lm_labels, mc_token_ids):
             model = OpenAIGPTDoubleHeadsModel(config)
             model.eval()
-            loss = model(input_ids, mc_token_mask,
+            loss = model(input_ids, mc_token_ids,
                          lm_labels=lm_labels, mc_labels=mc_labels,
                          token_type_ids=token_type_ids, position_ids=position_ids)
-            lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids)
+            lm_logits, mc_logits = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
             outputs = {
                 "loss": loss,
                 "lm_logits": lm_logits,