OpenDAS / Megatron-LM / Commits / 0295bb89

Commit 0295bb89, authored Jan 29, 2021 by Mostofa Patwary
WIP: main_retriver_merge
Parent: 17d897e0

Showing 2 changed files with 493 additions and 0 deletions:
  megatron/data/biencoder_dataset_utils.py  +202  -0
  megatron/model/biencoder_model.py         +291  -0
megatron/data/biencoder_dataset_utils.py (new file, mode 100644)

import os
import time

import numpy as np
import torch

from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
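

# A minimal sketch (not part of this commit) of how the deprecated
# DistributedBatchSampler above could be replaced with torch's built-in
# DistributedSampler while keeping the same sequential, keep-all-samples
# behaviour. The helper name and its use of the mpu ranks are assumptions
# for illustration only.
def _sequential_distributed_dataloader_sketch(dataset, micro_batch_size,
                                              num_workers=2):
    """Hypothetical replacement loader; shuffle=False preserves epoch order."""
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
        shuffle=False)
    # Note: DistributedSampler pads the last batches with repeated samples
    # instead of dropping them, so every example is seen at least once.
    return torch.utils.data.DataLoader(dataset,
                                       sampler=sampler,
                                       batch_size=micro_batch_size,
                                       num_workers=num_workers,
                                       pin_memory=True)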


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask, \
           context_tokens, context_mask, block_indices
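

# Worked example (toy values, not part of this commit): the broadcast masks
# arrive as int64 {0, 1} tensors, and the `< 0.5` comparison above converts
# them to boolean tensors that are True exactly where the original mask is 0.
def _mask_conversion_example():
    raw_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.int64)
    converted = raw_mask < 0.5
    # converted is tensor([[False, False, False, True, True]])
    return converted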


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
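

# Worked example (toy tokens, not part of this commit): WordPiece pieces that
# start with "##" are glued onto the previous token, every other piece gets a
# leading space, so the result carries one leading space.
def _join_str_list_example():
    pieces = ["em", "##bed", "##ding", "model"]
    return join_str_list(pieces)   # -> " embedding model"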


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx,
                         self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
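

# Illustrative usage (toy indices, not part of this commit): a block covering
# sentences 5..9 of document 42, registered as block 7.
def _block_sample_data_example():
    sample = BlockSampleData(start_idx=5, end_idx=9, doc_idx=42, block_idx=7)
    arr = sample.as_array()    # np.array([5, 9, 42, 7], dtype=np.int64)
    tup = sample.as_tuple()    # (5, 9, 42, 7)
    return arr, tup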


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data
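

# Illustrative usage (toy array, not part of this commit): the mapping wraps a
# (num_samples, 4) int64 array whose columns match BlockSampleData's fields.
def _block_samples_mapping_example():
    mapping_array = np.array([[0, 3, 0, 0],
                              [3, 6, 0, 1],
                              [0, 4, 1, 2]], dtype=np.int64)
    mapping = BlockSamplesMapping(mapping_array)
    first = mapping[0]                       # a BlockSampleData instance
    return len(mapping), first.as_tuple()    # (3, (0, 3, 0, 0))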


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix,
                              num_epochs, max_num_samples, max_seq_length,
                              seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        # compile/bind the C++ helper code
        from megatron.data.dataset_utils import compile_helper
        compile_helper()

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))

        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
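

# Worked example (toy arguments, not part of this commit) of the cache-file
# naming scheme used above when num_epochs is left unset; the prefix and
# values below are assumptions for illustration.
def _indexmap_filename_example():
    data_prefix, name = 'my-corpus_text_sentence', 'ict'
    max_num_samples, max_seq_length, seed = 1000000, 288, 1234
    # num_epochs falls back to the int32 sentinel, so no '_{}ep' part appears.
    filename = data_prefix + '_{}_indexmap'.format(name)
    filename += '_{}mns'.format(max_num_samples)
    filename += '_{}msl'.format(max_seq_length)
    filename += '_{}s'.format(seed)
    filename += '.npy'
    # -> 'my-corpus_text_sentence_ict_indexmap_1000000mns_288msl_1234s.npy'
    return filename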

megatron/model/biencoder_model.py (new file, mode 100644)

import os
import torch
import sys

from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, \
    get_checkpoint_name
from megatron.module import MegatronModule
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             shared_query_context_model=False):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        shared_query_context_model=shared_query_context_model)

    return model
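

# Usage sketch (not part of this commit): the provider is meant to be handed
# to Megatron's setup path after args and mpu have been initialized; the
# wrapper below only illustrates the three configurations it supports.
def _biencoder_provider_usage_sketch():
    # two separate BERT towers (default)
    two_tower = biencoder_model_provider()
    # a single tower shared by queries and contexts
    shared = biencoder_model_provider(shared_query_context_model=True)
    # query-only tower, e.g. for an embedding/indexing job
    query_only = biencoder_model_provider(only_query_model=True)
    return two_tower, shared, query_only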


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 shared_query_context_model=False):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

        self.shared_query_context_model = shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.projection_dim = args.projection_dim

        if self.shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")

        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")

        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model"""
        logits = model(tokens, attention_mask, token_types)
        return logits
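
    # Sketch of how the paired embeddings are typically consumed downstream
    # (not part of this commit): retrieval scores are dot products between
    # query and context embeddings, and an in-batch cross-entropy treats the
    # diagonal (the true query/context pairing) as the positive class. The
    # loss below is an assumption for illustration; this file defines no loss.
    @staticmethod
    def _in_batch_retrieval_loss_sketch(query_logits, context_logits):
        # [batch, batch] score matrix; entry (i, j) scores query i vs context j
        scores = torch.matmul(query_logits, context_logits.t())
        labels = torch.arange(scores.size(0), device=scores.device)
        return torch.nn.functional.cross_entropy(scores, labels)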

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining"""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        #for param in self.query_model.language_model.parameters():
        #    print(param.data)
        #break
        #sys.exit()

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except BaseException:
            raise ValueError("Could not load BERT checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same ict_head to begin with as well
                if self.projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and self.projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict\
                        (query_proj_state_dict)

        #for param in self.query_model.language_model.parameters():
        #    print(param.data)
        #    #sys.exit()


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2, parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.pool_type = args.pool_type
        self.projection_dim = args.projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if args.projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        if self.pool_type == "cls-token":
            pooled_output = lm_output[:, 0, :]
        elif self.pool_type == "avg":
            # Average Pooling
            pooled_output = lm_output.masked_fill(pool_mask, 0)
            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
                - pool_mask.float().sum(1))
        elif self.pool_type == "max":
            # Max-Pooling
            pooled_output = lm_output.masked_fill(pool_mask, -1000)
            pooled_output = torch.max(pooled_output, 1)[0]

        # Converting to float16 dtype
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output
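
    # Worked example (toy tensors, not part of this commit) of the "avg"
    # pooling branch above: padded positions are zeroed out, then the sum is
    # divided by the number of non-pad positions in each sequence.
    @staticmethod
    def _avg_pool_example():
        pad_id = 0
        input_ids = torch.tensor([[5, 6, 0, 0]])        # 2 real tokens, 2 pads
        lm_output = torch.ones(1, 4, 3)                 # [batch, seq, hidden]
        pool_mask = (input_ids == pad_id).unsqueeze(2)  # True at pad positions
        pooled = lm_output.masked_fill(pool_mask, 0)
        pooled = pooled.sum(1) / (pool_mask.size(1) - pool_mask.float().sum(1))
        # pooled == tensor([[1., 1., 1.]]): the average over the 2 real tokens
        return pooled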

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)