OpenDAS / Megatron-LM · Commits

Commit 5235ed87
Authored Apr 21, 2020 by Neel Kant
Parent: aae93362

Simplify batch and forward for ICT dataset and model

Showing 3 changed files with 101 additions and 71 deletions:
  megatron/data/ict_dataset.py  +29 -21
  megatron/model/bert_model.py  +58 -32
  pretrain_bert_ict.py          +14 -18
megatron/data/ict_dataset.py

@@ -28,13 +28,13 @@ class InverseClozeDataset(Dataset):
         self.samples_mapping = self.get_samples_mapping(
             data_prefix, num_epochs, max_num_samples)
-        tokenizer = get_tokenizer()
-        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_list = tokenizer.inv_vocab
-        self.cls_id = tokenizer.cls
-        self.sep_id = tokenizer.sep
-        self.mask_id = tokenizer.mask
-        self.pad_id = tokenizer.pad
+        self.tokenizer = get_tokenizer()
+        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
+        self.cls_id = self.tokenizer.cls
+        self.sep_id = self.tokenizer.sep
+        self.mask_id = self.tokenizer.mask
+        self.pad_id = self.tokenizer.pad

     def __len__(self):
         return self.samples_mapping.shape[0]

@@ -62,21 +62,36 @@ class InverseClozeDataset(Dataset):
         query = query[:self.max_seq_length - 2]
         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

-        query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

         sample = {
             'query_tokens': np.array(query_tokens),
-            'query_types': np.array(query_token_types),
             'query_pad_mask': np.array(query_pad_mask),
             'block_tokens': np.array(block_tokens),
-            'block_types': np.array(block_token_types),
             'block_pad_mask': np.array(block_pad_mask),
-            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
+            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         }

         return sample

+    def encode_text(self, text):
+        return self.tokenizer.tokenize(text)
+
+    def decode_tokens(self, token_ids):
+        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
+        return ' '.join(tokens)
+
+    def get_block(self, start_idx, end_idx, doc_idx):
+        """Get the IDs for an evidence block plus the title of the corresponding document"""
+        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
+        title = list(self.titles_dataset[int(doc_idx)])
+        block = list(itertools.chain(*block))[self.max_seq_length - (3 + len(title))]
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        return block_tokens, block_pad_mask
+
     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
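The renamed 'block_data' entry packs the same four indices as before, and the first three of them are exactly what the new three-argument get_block expects. A minimal NumPy sketch of that packing and unpacking, using made-up index values rather than anything produced by the samples mapping:

import numpy as np

# Hypothetical indices for one sample; in the dataset these come from the
# samples mapping (start/end offsets into the context dataset, plus the
# document index and block index).
start_idx, end_idx, doc_idx, block_idx = 10, 14, 3, 7

# Packed the same way the new 'block_data' entry is built in __getitem__.
block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)

# Downstream code can unpack it again; the first three values are the
# arguments of the new get_block(start_idx, end_idx, doc_idx).
start_idx, end_idx, doc_idx, block_idx = block_data.tolist()
print(block_data.dtype, (start_idx, end_idx, doc_idx))  # int64 (10, 14, 3)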
@@ -85,16 +100,9 @@ class InverseClozeDataset(Dataset):
         assert len(tokens) <= self.max_seq_length, len(tokens)
         num_pad = self.max_seq_length - len(tokens)
-        pad_mask = [0] * len(tokens) + [1] * num_pad
+        pad_mask = [1] * len(tokens) + [0] * num_pad
         tokens += [self.pad_id] * num_pad
-        token_types = [0] * self.max_seq_length
-        return tokens, token_types, pad_mask
-
-    def get_block(self, start_idx, end_idx, doc_idx, block_idx):
-        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
-        title = list(self.titles_dataset[int(doc_idx)])
-        block = list(itertools.chain(*block))[self.max_seq_length - (3 + len(title))]
+        return tokens, pad_mask

     def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
         if not num_epochs:
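The padding convention flips in this commit: concat_and_pad_tokens now returns a mask that is 1 over real tokens and 0 over padding, and it no longer returns token types. A standalone sketch of the new behaviour, where CLS_ID, SEP_ID, PAD_ID and MAX_SEQ_LENGTH are illustrative values rather than Megatron's:

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0   # hypothetical special-token ids
MAX_SEQ_LENGTH = 8                     # hypothetical sequence length

def concat_and_pad_tokens(tokens, max_seq_length=MAX_SEQ_LENGTH):
    """Concat [CLS] ... [SEP], then pad to max_seq_length, as in the new dataset."""
    tokens = [CLS_ID] + tokens + [SEP_ID]
    assert len(tokens) <= max_seq_length, len(tokens)
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad   # 1 = real token, 0 = pad
    tokens += [PAD_ID] * num_pad
    return tokens, pad_mask

tokens, pad_mask = concat_and_pad_tokens([7, 8, 9])
print(tokens)    # [101, 7, 8, 9, 102, 0, 0, 0]
print(pad_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]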
megatron/model/bert_model.py

@@ -273,8 +273,10 @@ class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""
     def __init__(self,
                  ict_head_size,
-                 num_tokentypes=2,
-                 parallel_output=True):
+                 num_tokentypes=1,
+                 parallel_output=True,
+                 only_query_model=False,
+                 only_block_model=False):
         super(ICTBertModel, self).__init__()
         bert_args = dict(
             num_tokentypes=num_tokentypes,

@@ -282,44 +284,68 @@ class ICTBertModel(MegatronModule):
             ict_head_size=ict_head_size,
             parallel_output=parallel_output
         )
+        assert not only_block_model and only_query_model
+        self.use_block_model = not only_query_model
+        self.use_query_model = not only_block_model
+
         # this model embeds (pseudo-)queries - Embed_input in the paper
-        self.query_model = BertModel(**bert_args)
-        self._query_key = 'question_model'
+        if self.use_query_model:
+            self.query_model = BertModel(**bert_args)
+            self._query_key = 'question_model'

         # this model embeds evidence blocks - Embed_doc in the paper
-        self.block_model = BertModel(**bert_args)
-        self._block_key = 'context_model'
+        if self.use_block_model:
+            self.block_model = BertModel(**bert_args)
+            self._block_key = 'context_model'

-    def forward(self, query_tokens, query_attention_mask, query_types,
-                block_tokens, block_attention_mask, block_types):
+    def forward(self, query_tokens, query_attention_mask,
+                block_tokens, block_attention_mask):
         """Run a forward pass for each of the models and compute the similarity scores."""
-        query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
-        block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
-        return query_logits, block_logits
-
-    def embed_query(self, query_tokens, query_attention_mask, query_types):
-        query_ict_logits, _ = self.question_model.forward(query_tokens, 1 - query_attention_mask, query_types)
-        return query_ict_logits
+        query_logits = self.embed_query(query_tokens, query_attention_mask)
+        block_logits = self.embed_block(block_tokens, block_attention_mask)
+
+        # [batch x embed] * [embed x batch]
+        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+        return retrieval_scores
+
+    def embed_query(self, query_tokens, query_attention_mask):
+        """Embed a batch of tokens using the query model"""
+        if self.use_query_model:
+            query_types = torch.zeros(query_tokens.shape).type(torch.float16).cuda()
+            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
+            return query_ict_logits
+        else:
+            raise ValueError("Cannot embed query without query model.")
+
+    def embed_block(self, block_tokens, block_attention_mask):
+        """Embed a batch of tokens using the block model"""
+        if self.use_block_model:
+            block_types = torch.zeros(block_tokens.shape).type(torch.float16).cuda()
+            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
+            return block_ict_logits
+        else:
+            raise ValueError("Cannot embed block without block model.")

     def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        state_dict_[self._query_key] \
-            = self.query_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
-        state_dict_[self._block_key] \
-            = self.block_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+
+        if self.use_query_model:
+            state_dict_[self._query_key] \
+                = self.query_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+
+        if self.use_block_model:
+            state_dict_[self._block_key] \
+                = self.block_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        self.query_model.load_state_dict(
-            state_dict[self._query_key], strict=strict)
-        self.block_model.load_state_dict(
-            state_dict[self._block_key], strict=strict)
+        if self.use_query_model:
+            self.query_model.load_state_dict(
+                state_dict[self._query_key], strict=strict)
+
+        if self.use_block_model:
+            self.block_model.load_state_dict(
+                state_dict[self._block_key], strict=strict)
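ICTBertModel.forward now returns the query-block similarity matrix directly instead of the two sets of logits. A small sketch of that score computation on dummy tensors (shapes are illustrative; in the real model the logits come from the query and block BertModels):

import torch

# Dummy query/block embeddings standing in for the outputs of the two models;
# batch and embedding sizes are illustrative.
batch_size, embed_dim = 4, 16
query_logits = torch.randn(batch_size, embed_dim)
block_logits = torch.randn(batch_size, embed_dim)

# [batch x embed] * [embed x batch] -> [batch x batch]; entry (i, j) scores
# query i against block j, so the diagonal holds each query's own block.
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
print(retrieval_scores.shape)  # torch.Size([4, 4])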
pretrain_bert_ict.py

@@ -43,10 +43,9 @@ def model_provider():
 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['query_tokens', 'query_types', 'query_pad_mask',
-            'block_tokens', 'block_types', 'block_pad_mask', 'block_indices']
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
     datatype = torch.int64

     # Broadcast data.

@@ -58,15 +57,13 @@ def get_batch(data_iterator):
     # Unpack.
     query_tokens = data_b['query_tokens'].long()
-    query_types = data_b['query_types'].long()
     query_pad_mask = data_b['query_pad_mask'].long()
     block_tokens = data_b['block_tokens'].long()
-    block_types = data_b['block_types'].long()
     block_pad_mask = data_b['block_pad_mask'].long()
-    block_indices = data_b['block_indices'].long()
+    block_indices = data_b['block_data'].long()

-    return query_tokens, query_types, query_pad_mask, \
-           block_tokens, block_types, block_pad_mask, block_indices
+    return query_tokens, query_pad_mask, \
+           block_tokens, block_pad_mask, block_indices


 def forward_step(data_iterator, model):

@@ -75,16 +72,12 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
-    query_tokens, query_types, query_pad_mask, \
-    block_tokens, block_types, block_pad_mask, block_indices = get_batch(data_iterator)
+    query_tokens, query_pad_mask, \
+    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
-                                       block_tokens, block_pad_mask, block_types).float()
-
-    # [batch x h] * [h x batch]
-    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+    retrieval_scores = model(query_tokens, query_pad_mask,
+                             block_tokens, block_pad_mask).float()
     softmaxed = F.softmax(retrieval_scores, dim=1)

     top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)

@@ -95,10 +88,13 @@ def forward_step(data_iterator, model):
     retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
     reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])

-    return retrieval_loss, {'retrieval loss': reduced_losses[0],
-                            'top1_acc': reduced_losses[1],
-                            'top5_acc': reduced_losses[2]}
+    stats_dict = {
+        'retrieval loss': reduced_losses[0],
+        'top1_acc': reduced_losses[1],
+        'top5_acc': reduced_losses[2]
+    }
+    return retrieval_loss, stats_dict


 def train_valid_test_datasets_provider(train_val_test_num_samples):
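forward_step now consumes that score matrix directly: the training signal is an in-batch softmax cross-entropy where the target for query i is block i. A CPU-only sketch mirroring the script's steps on dummy scores (the script itself runs on CUDA with k=5; the toy values here are illustrative, and k is capped so the small batch works):

import torch
import torch.nn.functional as F

# Dummy retrieval scores; in training these come from model(...).float().
batch_size = 4
retrieval_scores = torch.randn(batch_size, batch_size)

# Mirror forward_step: softmax over blocks, then cross-entropy against
# torch.arange(batch_size), i.e. the correct block for query i is block i.
softmaxed = F.softmax(retrieval_scores, dim=1)
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size))

# Top-k bookkeeping in the same spirit as the script.
top5_vals, top5_indices = torch.topk(softmaxed, k=min(5, batch_size), sorted=True)
top1_acc = (top5_indices[:, 0] == torch.arange(batch_size)).float().mean()
print(retrieval_loss.item(), top1_acc.item())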