OpenDAS / Megatron-LM / Commits

Commit 11f76cd3
Authored Jul 22, 2020 by Neel Kant
Parent: 7a348580

    Address comments from last week

Showing 10 changed files with 300 additions and 77 deletions (+300 -77)
megatron/arguments.py                   +8   -2
megatron/checkpointing.py               +42  -1
megatron/data/dataset_utils.py          +1   -1
megatron/data/ict_dataset.py            +28  -1
megatron/data/realm_dataset_utils.py    +55  -3
megatron/data/realm_index.py            +20  -17
megatron/indexer.py                     +90  -0
megatron/model/realm_model.py           +23  -1
pretrain_ict.py                         +5   -51
tools/create_doc_index.py               +28  -0
megatron/arguments.py

@@ -411,7 +411,7 @@ def _add_realm_args(parser):
                        help='Path to titles dataset used for ICT')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--ict-one-sent', action='store_true',
+    group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')

     # training
@@ -421,7 +421,13 @@ def _add_realm_args(parser):
     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
     group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')

+    # indexer
+    group.add_argument('--indexer-batch-size', type=int, default=128,
+                       help='How large of batches to use when doing indexing jobs')
+    group.add_argument('--indexer-log-interval', type=int, default=1000,
+                       help='After how many batches should the indexer report progress')
+
     return parser
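The net effect here is one renamed flag (--ict-one-sent becomes --use-one-sent-docs) and two new indexer flags. As a quick orientation for readers outside Megatron, a minimal stand-alone argparse sketch (not part of the commit) showing how the new options surface on the parsed namespace:

import argparse

# Hypothetical stand-alone parser mirroring the flags added/renamed in this hunk.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('realm')
group.add_argument('--use-one-sent-docs', action='store_true',
                   help='Whether to use one sentence documents in ICT')
group.add_argument('--indexer-batch-size', type=int, default=128,
                   help='How large of batches to use when doing indexing jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
                   help='After how many batches should the indexer report progress')

args = parser.parse_args(['--use-one-sent-docs', '--indexer-batch-size', '256'])
# Dashes become underscores on the namespace:
# args.use_one_sent_docs == True, args.indexer_batch_size == 256, args.indexer_log_interval == 1000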
megatron/checkpointing.py

@@ -21,9 +21,10 @@ import sys
 import numpy as np

 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import mpu
+from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
@@ -244,3 +245,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             print(' successfully loaded {}'.format(checkpoint_name))

     return iteration
+
+
+def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+    args = get_args()
+
+    if isinstance(model, torchDDP):
+        model = model.module
+
+    load_path = args.load if from_realm_chkpt else args.ict_load
+
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
+    with open(tracker_filename, 'r') as f:
+        iteration = int(f.read().strip())
+        # assert iteration > 0
+
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    if mpu.get_data_parallel_rank() == 0:
+        print('global rank {} is loading checkpoint {}'.format(
+            torch.distributed.get_rank(), checkpoint_name))
+
+    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
+        print(" loading ICT state dict from REALM", flush=True)
+        ict_state_dict = ict_state_dict['retriever']['ict_model']
+
+    if only_query_model:
+        ict_state_dict.pop('context_model')
+    if only_block_model:
+        ict_state_dict.pop('question_model')
+
+    model.load_state_dict(ict_state_dict)
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' successfully loaded {}'.format(checkpoint_name))
+
+    return model
\ No newline at end of file
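Usage note: the rewritten megatron/indexer.py later in this commit constructs the model first and then hands it to load_ict_checkpoint. A condensed sketch of that call pattern, assuming Megatron has been initialized and --ict-load points at an ICT checkpoint:

from megatron.checkpointing import load_ict_checkpoint
from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model

# Build only the block ("context") tower, then selectively load its weights.
model = get_model(lambda: general_ict_model_provider(only_block_model=True))
model = load_ict_checkpoint(model,
                            only_block_model=True,   # pops 'question_model' from the state dict
                            from_realm_chkpt=False)  # read from args.ict_load rather than args.load
model.eval()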
megatron/data/dataset_utils.py

@@ -426,7 +426,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
             query_in_block_prob=args.query_in_block_prob,
-            use_one_sent_docs=args.ict_one_sent,
+            use_one_sent_docs=args.use_one_sent_docs,
             **kwargs
         )
     else:
megatron/data/ict_dataset.py

@@ -5,9 +5,36 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron import get_args
+from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
+    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
+    rather than for training, since it is only built with a single epoch sample mapping.
+    """
+    args = get_args()
+    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+
+    kwargs = dict(
+        name='full',
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
+        data_prefix=args.data_path,
+        num_epochs=1,
+        max_num_samples=None,
+        max_seq_length=args.seq_length,
+        seed=1,
+        query_in_block_prob=query_in_block_prob,
+        use_titles=use_titles,
+        use_one_sent_docs=args.use_one_sent_docs
+    )
+    dataset = ICTDataset(**kwargs)
+    return dataset
+
+
 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
@@ -35,7 +62,7 @@ class ICTDataset(Dataset):
         self.pad_id = self.tokenizer.pad

     def __len__(self):
-        return self.samples_mapping.shape[0]
+        return len(self.samples_mapping)

     def __getitem__(self, idx):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
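Since get_ict_dataset builds only a single-epoch sample mapping, it is meant for indexing jobs rather than training. A short sketch of the intended use, assuming Megatron has been initialized with --data-path and --titles-data-path:

from megatron.data.ict_dataset import get_ict_dataset

# One-epoch dataset over every block; the default query_in_block_prob=1 keeps the
# pseudo-query sentence inside its block, which is what block indexing wants.
dataset = get_ict_dataset(use_titles=True, query_in_block_prob=1)
print(len(dataset))  # number of block samples in the single-epoch mapping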
megatron/data/realm_dataset_utils.py

@@ -6,9 +6,59 @@ import torch
 from megatron import mpu, print_rank_0
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron.data.samplers import DistributedBatchSampler
 from megatron import get_args, get_tokenizer, print_rank_0, mpu


+def get_one_epoch_dataloader(dataset, batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+    # Items and their type.
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
+    datatype = torch.int64
+
+    # Broadcast data.
+    if data_iterator is None:
+        data = None
+    else:
+        data = next(data_iterator)
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    query_tokens = data_b['query_tokens'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_data'].long()
+
+    return query_tokens, query_pad_mask,\
+        block_tokens, block_pad_mask, block_indices
+
+
 def join_str_list(str_list):
     """Join a list of strings, handling spaces appropriately"""
     result = ""
@@ -46,10 +96,12 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
-        self.shape = self.mapping_array.shape
+
+    def __len__(self):
+        return self.mapping_array.shape[0]

     def __getitem__(self, idx):
-        """Get the data associated with a particular sample."""
+        """Get the data associated with an indexed sample."""
         sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data
@@ -144,6 +196,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
-        samples_mapping.shape[0]))
+        mapping_array.shape[0]))

     return samples_mapping
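The two new helpers are designed to be used together by an indexing job, which is how IndexBuilder consumes them later in this commit. A condensed sketch (batch size illustrative):

from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.realm_dataset_utils import get_one_epoch_dataloader, get_ict_batch

# One sequential pass over the block dataset, sharded across data-parallel ranks.
dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), batch_size=128))
while True:
    try:
        # query fields are unused when only block embeddings are needed
        _, _, block_tokens, block_pad_mask, block_data = get_ict_batch(dataloader)
    except StopIteration:
        break
    # ... embed block_tokens/block_pad_mask and store alongside block_data ...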
megatron/data/realm_index.py

@@ -3,7 +3,6 @@ import os
 import pickle
 import shutil

-import faiss
 import numpy as np
 import torch
@@ -123,6 +122,11 @@ class FaissMIPSIndex(object):
     def _set_block_index(self):
         """Create a Faiss Flat index with inner product as the metric to search against"""
+        try:
+            import faiss
+        except ImportError:
+            raise Exception("Error: Please install faiss to use FaissMIPSIndex")
+
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
@@ -188,19 +192,18 @@ class FaissMIPSIndex(object):
         if False: return [num_queries x k] array of distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

-        with torch.no_grad():
         if reconstruct:
             # get the vectors themselves
             top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
             return top_k_block_embeds
         else:
             # get distances and indices of closest vectors
             distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
             if self.use_gpu:
                 fresh_indices = np.zeros(block_indices.shape)
                 for i, j in itertools.product(block_indices.shape):
                     fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                 block_indices = fresh_indices
             return distances, block_indices
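For readers unfamiliar with Faiss, a minimal self-contained sketch (not part of the commit) of the pattern _set_block_index relies on: an exact 'Flat' index scored by inner product, which is what maximum inner product search over block embeddings needs. Sizes below are placeholders:

import faiss
import numpy as np

embed_size = 128  # placeholder embedding width
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

block_embeds = np.random.rand(1000, embed_size).astype('float32')  # stand-in block embeddings
index.add(block_embeds)                                            # add database vectors

query_embeds = np.random.rand(4, embed_size).astype('float32')     # stand-in query embeddings
distances, indices = index.search(query_embeds, 5)                 # top-5 inner products per query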
indexer.py → megatron/indexer.py

 import torch
 import torch.distributed as dist
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.ict_dataset import ICTDataset
+from megatron.checkpointing import load_ict_checkpoint
+from megatron.data.ict_dataset import get_ict_dataset
+from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, BlockData
-from megatron.data.samplers import DistributedBatchSampler
-from megatron.initialize import initialize_megatron
+from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
-from pretrain_ict import get_batch, general_ict_model_provider
-
-
-def pprint(*args):
-    print(*args, flush=True)


 class IndexBuilder(object):
@@ -30,22 +24,27 @@ class IndexBuilder(object):
         assert not (args.load and args.ict_load)
         self.using_realm_chkpt = args.ict_load is None
+
+        self.log_interval = args.indexer_log_interval
+        self.batch_size = args.indexer_batch_size
+
         self.load_attributes()
-        self.is_main_builder = args.rank == 0
+        self.is_main_builder = mpu.get_data_parallel_rank() == 0
+        self.num_total_builders = mpu.get_data_parallel_world_size()
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
+        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
         self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
         self.iteration += 1
-        self.total_processed += batch_size
-        if self.iteration % 10 == 0:
+        self.total_processed += batch_size * self.num_total_builders
+        if self.is_main_builder and self.iteration % self.log_interval == 0:
             print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

     def build_and_save_index(self):
@@ -58,17 +57,20 @@ class IndexBuilder(object):
         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
-            except:
+                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+            except StopIteration:
                 break

-            # detach, setup and add to BlockData
             unwrapped_model = self.model
             while not hasattr(unwrapped_model, 'embed_block'):
                 unwrapped_model = unwrapped_model.module

+            # detach, separate fields and add to BlockData
             block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
             detached_data = detach(block_sample_data)

+            # block_sample_data is a 2D array [batch x 4]
+            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
             block_indices = detached_data[:, 3]
             block_metas = detached_data[:, :3]
@@ -86,98 +88,3 @@ class IndexBuilder(object):
             self.block_data.clear()
-
-
-def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
-    args = get_args()
-    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))
-
-    if isinstance(model, torchDDP):
-        model = model.module
-
-    load_path = args.load if from_realm_chkpt else args.ict_load
-
-    tracker_filename = get_checkpoint_tracker_filename(load_path)
-    with open(tracker_filename, 'r') as f:
-        iteration = int(f.read().strip())
-        # assert iteration > 0
-
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
-
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt:
-        print(">>>> Attempting to get ict state dict from realm", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
-
-    if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
-
-    model.load_state_dict(ict_state_dict)
-    torch.distributed.barrier()
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-
-    return model
-
-
-def get_ict_dataset(use_titles=True, query_in_block_prob=1):
-    """Get a dataset which uses block samples mappings to get ICT/block indexing data"""
-    args = get_args()
-    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
-    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
-
-    kwargs = dict(
-        name='full',
-        block_dataset=block_dataset,
-        title_dataset=titles_dataset,
-        data_prefix=args.data_path,
-        num_epochs=1,
-        max_num_samples=None,
-        max_seq_length=args.seq_length,
-        seed=1,
-        query_in_block_prob=query_in_block_prob,
-        use_titles=use_titles,
-        use_one_sent_docs=True
-    )
-    dataset = ICTDataset(**kwargs)
-    return dataset
-
-
-def get_one_epoch_dataloader(dataset, batch_size=None):
-    """Specifically one epoch to be used in an indexing job."""
-    args = get_args()
-
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    if batch_size is None:
-        batch_size = args.batch_size
-    global_batch_size = batch_size * world_size
-    num_workers = args.num_workers
-
-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
-
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
-if __name__ == "__main__":
-    # This usage is for basic (as opposed to realm async) indexing jobs.
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = IndexBuilder()
-    index_builder.build_and_save_index()
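The constructor and progress-tracking changes mean each builder now reports a global sample count (its own batch size times the number of data-parallel builders) and only the main builder prints, at --indexer-log-interval. A tiny illustrative sketch of that accounting, with all numbers made up:

# Illustrative numbers only.
num_total_builders = 8      # stand-in for mpu.get_data_parallel_world_size()
is_main_builder = True      # stand-in for mpu.get_data_parallel_rank() == 0
batch_size = 128
log_interval = 1000

iteration = total_processed = 0
for _ in range(3000):
    iteration += 1
    total_processed += batch_size * num_total_builders  # every rank contributes a batch per step
    if is_main_builder and iteration % log_interval == 0:
        print('Batch {:10d} | Total {:10d}'.format(iteration, total_processed))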
megatron/model/realm_model.py

 import os
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from megatron.module import MegatronModule
@@ -13,6 +13,28 @@ from megatron.model.utils import scaled_init_method_normal
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids


+def general_ict_model_provider(only_query_model=False, only_block_model=False):
+    """Build the model."""
+    args = get_args()
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
+
+    print_rank_0('building ICTBertModel...')
+
+    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
+    model = ICTBertModel(
+        ict_head_size=args.ict_head_size,
+        num_tokentypes=2,
+        parallel_output=True,
+        only_query_model=only_query_model,
+        only_block_model=only_block_model)
+
+    return model
+
+
 class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""
     def __init__(self,
pretrain_ict.py

@@ -27,33 +27,11 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
 from megatron.utils import reduce_losses
+from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.realm_dataset_utils import get_ict_batch

+num_batches = 0

-def general_ict_model_provider(only_query_model=False, only_block_model=False):
-    """Build the model."""
-    args = get_args()
-    assert args.ict_head_size is not None, \
-        "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.model_parallel_size == 1, \
-        "Model parallel size > 1 not supported for ICT"
-
-    print_rank_0('building ICTBertModel...')
-
-    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
-    model = ICTBertModel(
-        ict_head_size=args.ict_head_size,
-        num_tokentypes=2,
-        parallel_output=True,
-        only_query_model=only_query_model,
-        only_block_model=only_block_model)
-
-    return model
-
-
-def model_provider():
+def pretrain_ict_model_provider():
     return general_ict_model_provider(False, False)
@@ -95,30 +73,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         return output


-def get_batch(data_iterator):
-    # Items and their type.
-    keys = ['query_tokens', 'query_pad_mask',
-            'block_tokens', 'block_pad_mask', 'block_data']
-    datatype = torch.int64
-
-    # Broadcast data.
-    if data_iterator is None:
-        data = None
-    else:
-        data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
-
-    # Unpack.
-    query_tokens = data_b['query_tokens'].long()
-    query_pad_mask = data_b['query_pad_mask'].long()
-    block_tokens = data_b['block_tokens'].long()
-    block_pad_mask = data_b['block_pad_mask'].long()
-    block_indices = data_b['block_data'].long()
-
-    return query_tokens, query_pad_mask,\
-        block_tokens, block_pad_mask, block_indices
-
-
 def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
@@ -127,7 +81,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
+    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
     timers('batch generator').stop()
@@ -181,5 +135,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
tools/create_doc_index.py (new file, 0 → 100644)

import sys
sys.path.append('../')

from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron


def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset

    Check README.md for example script
    """

    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()


if __name__ == "__main__":
    main()