OpenDAS / Megatron-LM · Commits · 11f76cd3

Commit 11f76cd3, authored Jul 22, 2020 by Neel Kant
Parent: 7a348580

    Address comments from last week

Showing 10 changed files with 300 additions and 77 deletions.
megatron/arguments.py                  +8   -2
megatron/checkpointing.py              +42  -1
megatron/data/dataset_utils.py         +1   -1
megatron/data/ict_dataset.py           +28  -1
megatron/data/realm_dataset_utils.py   +55  -3
megatron/data/realm_index.py           +20  -17
megatron/indexer.py                    +90  -0
megatron/model/realm_model.py          +23  -1
pretrain_ict.py                        +5   -51
tools/create_doc_index.py              +28  -0
megatron/arguments.py  (view file @ 11f76cd3)

@@ -411,7 +411,7 @@ def _add_realm_args(parser):
                        help='Path to titles dataset used for ICT')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--ict-one-sent', action='store_true',
+    group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')

     # training
@@ -421,7 +421,13 @@ def _add_realm_args(parser):
     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str,
+    group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
+
+    # indexer
+    group.add_argument('--indexer-batch-size', type=int, default=128,
+                       help='How large of batches to use when doing indexing jobs')
+    group.add_argument('--indexer-log-interval', type=int, default=1000,
+                       help='After how many batches should the indexer report progress')

     return parser
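Note: the renamed flag and the new indexer flags surface on the parsed namespace as args.use_one_sent_docs, args.indexer_batch_size and args.indexer_log_interval. A minimal, self-contained sketch (a bare argparse parser mirroring just these flags, not Megatron's full parser) of how those attribute names fall out of the flag spellings:

    # Illustrative only; values passed to parse_args are made up.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('realm')
    group.add_argument('--use-one-sent-docs', action='store_true')
    group.add_argument('--indexer-batch-size', type=int, default=128)
    group.add_argument('--indexer-log-interval', type=int, default=1000)

    args = parser.parse_args(['--use-one-sent-docs', '--indexer-batch-size', '64'])
    assert args.use_one_sent_docs and args.indexer_batch_size == 64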
megatron/checkpointing.py  (view file @ 11f76cd3)

@@ -21,9 +21,10 @@ import sys
 import numpy as np

 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import mpu
+from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
@@ -244,3 +245,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             print('  successfully loaded {}'.format(checkpoint_name))

     return iteration
+
+
+def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+    args = get_args()
+
+    if isinstance(model, torchDDP):
+        model = model.module
+
+    load_path = args.load if from_realm_chkpt else args.ict_load
+
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
+    with open(tracker_filename, 'r') as f:
+        iteration = int(f.read().strip())
+
+    # assert iteration > 0
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    if mpu.get_data_parallel_rank() == 0:
+        print('global rank {} is loading checkpoint {}'.format(
+            torch.distributed.get_rank(), checkpoint_name))
+
+    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
+        print(" loading ICT state dict from REALM", flush=True)
+        ict_state_dict = ict_state_dict['retriever']['ict_model']
+
+    if only_query_model:
+        ict_state_dict.pop('context_model')
+    if only_block_model:
+        ict_state_dict.pop('question_model')
+
+    model.load_state_dict(ict_state_dict)
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' successfully loaded {}'.format(checkpoint_name))
+
+    return model
\ No newline at end of file
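Note: unlike the copy removed from indexer.py later in this diff, the new load_ict_checkpoint expects the caller to construct the model first. A sketch of the intended call pattern, mirroring IndexBuilder.load_attributes from this same commit; it assumes initialize_megatron() has already run so get_args() is populated:

    # Sketch based on megatron/indexer.py in this commit.
    from megatron.checkpointing import load_ict_checkpoint
    from megatron.model.realm_model import general_ict_model_provider
    from megatron.training import get_model

    model = get_model(lambda: general_ict_model_provider(only_block_model=True))
    model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=False)
    model.eval()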
megatron/data/dataset_utils.py  (view file @ 11f76cd3)

@@ -426,7 +426,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
             query_in_block_prob=args.query_in_block_prob,
-            use_one_sent_docs=args.ict_one_sent,
+            use_one_sent_docs=args.use_one_sent_docs,
             **kwargs
         )
     else:
megatron/data/ict_dataset.py  (view file @ 11f76cd3)

@@ -5,9 +5,36 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping


+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
+    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
+    rather than for training, since it is only built with a single epoch sample mapping.
+    """
+    args = get_args()
+    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+
+    kwargs = dict(
+        name='full',
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
+        data_prefix=args.data_path,
+        num_epochs=1,
+        max_num_samples=None,
+        max_seq_length=args.seq_length,
+        seed=1,
+        query_in_block_prob=query_in_block_prob,
+        use_titles=use_titles,
+        use_one_sent_docs=args.use_one_sent_docs)
+    dataset = ICTDataset(**kwargs)
+    return dataset
+
+
 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
@@ -35,7 +62,7 @@ class ICTDataset(Dataset):
         self.pad_id = self.tokenizer.pad

     def __len__(self):
-        return self.samples_mapping.shape[0]
+        return len(self.samples_mapping)

     def __getitem__(self, idx):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
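Note: get_ict_dataset builds a single-epoch ICTDataset intended for indexing jobs rather than training. A sketch of how it is meant to be paired with the one-epoch dataloader added in realm_dataset_utils.py (mirrors IndexBuilder.load_attributes; the batch size here is illustrative):

    from megatron.data.ict_dataset import get_ict_dataset
    from megatron.data.realm_dataset_utils import get_one_epoch_dataloader

    dataset = get_ict_dataset()  # defaults: use_titles=True, query_in_block_prob=1
    dataloader = iter(get_one_epoch_dataloader(dataset, batch_size=128))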
megatron/data/realm_dataset_utils.py  (view file @ 11f76cd3)

@@ -6,9 +6,59 @@ import torch

-from megatron import mpu, print_rank_0
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron.data.samplers import DistributedBatchSampler
+from megatron import get_args, get_tokenizer, print_rank_0, mpu
+
+
+def get_one_epoch_dataloader(dataset, batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+    # Items and their type.
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
+    datatype = torch.int64
+
+    # Broadcast data.
+    if data_iterator is None:
+        data = None
+    else:
+        data = next(data_iterator)
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    query_tokens = data_b['query_tokens'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_data'].long()
+
+    return query_tokens, query_pad_mask,\
+        block_tokens, block_pad_mask, block_indices
+
+
 def join_str_list(str_list):
     """Join a list of strings, handling spaces appropriately"""
     result = ""
@@ -46,10 +96,12 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
         self.shape = self.mapping_array.shape

+    def __len__(self):
+        return self.mapping_array.shape[0]
+
     def __getitem__(self, idx):
-        """Get the data associated with a particular sample."""
+        """Get the data associated with an indexed sample."""
         sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data
@@ -144,6 +196,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
-        samples_mapping.shape[0]))
+        mapping_array.shape[0]))

     return samples_mapping
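Note: the new dataloader and batch helper are meant to be consumed together until the sequential sampler runs dry. A condensed sketch of that loop; embed_all_blocks and embed_fn are hypothetical names used only for illustration here, the real loop lives in IndexBuilder.build_and_save_index:

    from megatron.data.realm_dataset_utils import get_ict_batch, get_one_epoch_dataloader

    def embed_all_blocks(dataset, batch_size, embed_fn):
        """Hypothetical helper: one pass over `dataset`, handing each block batch to embed_fn."""
        dataloader = iter(get_one_epoch_dataloader(dataset, batch_size))
        while True:
            try:
                _, _, block_tokens, block_pad_mask, block_data = get_ict_batch(dataloader)
            except StopIteration:
                break  # drop_last=False above, so every sample is visited exactly once
            embed_fn(block_tokens, block_pad_mask, block_data)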
megatron/data/realm_index.py  (view file @ 11f76cd3)

@@ -3,7 +3,6 @@ import os
 import pickle
 import shutil

-import faiss
 import numpy as np
 import torch
@@ -123,6 +122,11 @@ class FaissMIPSIndex(object):
     def _set_block_index(self):
         """Create a Faiss Flat index with inner product as the metric to search against"""
+        try:
+            import faiss
+        except ImportError:
+            raise Exception("Error: Please install faiss to use FaissMIPSIndex")
+
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
@@ -188,19 +192,18 @@
         if False: return [num_queries x k] array of distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

-        with torch.no_grad():
-            if reconstruct:
-                # get the vectors themselves
-                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
-                return top_k_block_embeds
-            else:
-                # get distances and indices of closest vectors
-                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-                if self.use_gpu:
-                    fresh_indices = np.zeros(block_indices.shape)
-                    for i, j in itertools.product(block_indices.shape):
-                        fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                    block_indices = fresh_indices
-                return distances, block_indices
+        if reconstruct:
+            # get the vectors themselves
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            return top_k_block_embeds
+        else:
+            # get distances and indices of closest vectors
+            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+            if self.use_gpu:
+                fresh_indices = np.zeros(block_indices.shape)
+                for i, j in itertools.product(block_indices.shape):
+                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
+                block_indices = fresh_indices
+            return distances, block_indices
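Note: moving the faiss import inside _set_block_index makes faiss an optional dependency; importing megatron.data.realm_index no longer fails on machines without faiss, only actually building a FaissMIPSIndex does. The guarded-import pattern in isolation (a generic sketch, not code from this commit):

    # Generic sketch of the lazy/guarded import applied above.
    def _load_faiss():
        try:
            import faiss
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")
        return faiss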
indexer.py → megatron/indexer.py  (view file @ 11f76cd3)

 import torch
 import torch.distributed as dist
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.ict_dataset import ICTDataset
+from megatron.checkpointing import load_ict_checkpoint
+from megatron.data.ict_dataset import get_ict_dataset
+from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, BlockData
-from megatron.data.samplers import DistributedBatchSampler
-from megatron.initialize import initialize_megatron
+from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
-from pretrain_ict import get_batch, general_ict_model_provider
-
-
-def pprint(*args):
-    print(*args, flush=True)


 class IndexBuilder(object):
@@ -30,22 +24,27 @@ class IndexBuilder(object):
         assert not (args.load and args.ict_load)
         self.using_realm_chkpt = args.ict_load is None

+        self.log_interval = args.indexer_log_interval
+        self.batch_size = args.indexer_batch_size
+
         self.load_attributes()
-        self.is_main_builder = args.rank == 0
+        self.is_main_builder = mpu.get_data_parallel_rank() == 0
+        self.num_total_builders = mpu.get_data_parallel_world_size()
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
+        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
         self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
         self.iteration += 1
-        self.total_processed += batch_size
-        if self.iteration % 10 == 0:
+        self.total_processed += batch_size * self.num_total_builders
+        if self.is_main_builder and self.iteration % self.log_interval == 0:
             print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

     def build_and_save_index(self):
@@ -58,17 +57,20 @@ class IndexBuilder(object):
         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
-            except:
+                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+            except StopIteration:
                 break

-            # detach, setup and add to BlockData
             unwrapped_model = self.model
             while not hasattr(unwrapped_model, 'embed_block'):
                 unwrapped_model = unwrapped_model.module
-            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
+
+            # detach, separate fields and add to BlockData
+            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
             detached_data = detach(block_sample_data)
+
+            # block_sample_data is a 2D array [batch x 4]
+            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
             block_indices = detached_data[:, 3]
             block_metas = detached_data[:, :3]
@@ -86,98 +88,3 @@ class IndexBuilder(object):
                 self.block_data.clear()
-
-
-def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
-    args = get_args()
-    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))
-
-    if isinstance(model, torchDDP):
-        model = model.module
-
-    load_path = args.load if from_realm_chkpt else args.ict_load
-
-    tracker_filename = get_checkpoint_tracker_filename(load_path)
-    with open(tracker_filename, 'r') as f:
-        iteration = int(f.read().strip())
-
-    # assert iteration > 0
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
-
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt:
-        print(">>>> Attempting to get ict state dict from realm", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
-
-    if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
-
-    model.load_state_dict(ict_state_dict)
-    torch.distributed.barrier()
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-
-    return model
-
-
-def get_ict_dataset(use_titles=True, query_in_block_prob=1):
-    """Get a dataset which uses block samples mappings to get ICT/block indexing data"""
-    args = get_args()
-    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
-    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
-
-    kwargs = dict(
-        name='full',
-        block_dataset=block_dataset,
-        title_dataset=titles_dataset,
-        data_prefix=args.data_path,
-        num_epochs=1,
-        max_num_samples=None,
-        max_seq_length=args.seq_length,
-        seed=1,
-        query_in_block_prob=query_in_block_prob,
-        use_titles=use_titles,
-        use_one_sent_docs=True)
-    dataset = ICTDataset(**kwargs)
-    return dataset
-
-
-def get_one_epoch_dataloader(dataset, batch_size=None):
-    """Specifically one epoch to be used in an indexing job."""
-    args = get_args()
-
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    if batch_size is None:
-        batch_size = args.batch_size
-    global_batch_size = batch_size * world_size
-    num_workers = args.num_workers
-
-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
-
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
-if __name__ == "__main__":
-    # This usage is for basic (as opposed to realm async) indexing jobs.
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = IndexBuilder()
-    index_builder.build_and_save_index()
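Note: with the module now living at megatron.indexer, an indexing job is driven by initializing Megatron and calling build_and_save_index. This is exactly what the new tools/create_doc_index.py (last file in this diff) does; a condensed driver sketch:

    from megatron.indexer import IndexBuilder
    from megatron.initialize import initialize_megatron

    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    IndexBuilder().build_and_save_index()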
megatron/model/realm_model.py  (view file @ 11f76cd3)

 import os
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from megatron.module import MegatronModule
@@ -13,6 +13,28 @@ from megatron.model.utils import scaled_init_method_normal
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids


+def general_ict_model_provider(only_query_model=False, only_block_model=False):
+    """Build the model."""
+    args = get_args()
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
+
+    print_rank_0('building ICTBertModel...')
+
+    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
+    model = ICTBertModel(
+        ict_head_size=args.ict_head_size,
+        num_tokentypes=2,
+        parallel_output=True,
+        only_query_model=only_query_model,
+        only_block_model=only_block_model)
+
+    return model
+
+
 class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""
     def __init__(self,
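Note: hoisting general_ict_model_provider into realm_model.py lets both the pretraining script and the indexer import the same provider instead of each defining its own. A sketch of the two call sites this commit sets up:

    from megatron.model.realm_model import general_ict_model_provider
    from megatron.training import get_model

    # pretrain_ict.py: full query + block model
    def pretrain_ict_model_provider():
        return general_ict_model_provider(False, False)

    # megatron/indexer.py: block-embedding half only, wrapped by get_model
    block_model = get_model(lambda: general_ict_model_provider(only_block_model=True))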
pretrain_ict.py  (view file @ 11f76cd3)

@@ -27,33 +27,11 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
 from megatron.utils import reduce_losses
+from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.realm_dataset_utils import get_ict_batch

 num_batches = 0


-def general_ict_model_provider(only_query_model=False, only_block_model=False):
-    """Build the model."""
-    args = get_args()
-    assert args.ict_head_size is not None, \
-        "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.model_parallel_size == 1, \
-        "Model parallel size > 1 not supported for ICT"
-
-    print_rank_0('building ICTBertModel...')
-
-    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
-    model = ICTBertModel(
-        ict_head_size=args.ict_head_size,
-        num_tokentypes=2,
-        parallel_output=True,
-        only_query_model=only_query_model,
-        only_block_model=only_block_model)
-
-    return model
-
-
-def model_provider():
+def pretrain_ict_model_provider():
     return general_ict_model_provider(False, False)
@@ -95,30 +73,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         return output


-def get_batch(data_iterator):
-    # Items and their type.
-    keys = ['query_tokens', 'query_pad_mask',
-            'block_tokens', 'block_pad_mask', 'block_data']
-    datatype = torch.int64
-
-    # Broadcast data.
-    if data_iterator is None:
-        data = None
-    else:
-        data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
-
-    # Unpack.
-    query_tokens = data_b['query_tokens'].long()
-    query_pad_mask = data_b['query_pad_mask'].long()
-    block_tokens = data_b['block_tokens'].long()
-    block_pad_mask = data_b['block_pad_mask'].long()
-    block_indices = data_b['block_data'].long()
-
-    return query_tokens, query_pad_mask,\
-        block_tokens, block_pad_mask, block_indices
-
-
 def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
@@ -127,7 +81,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
+    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
     timers('batch generator').stop()
@@ -181,5 +135,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
tools/create_doc_index.py  (new file, 0 → 100644)  (view file @ 11f76cd3)

import sys
sys.path.append('../')

from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron


def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset

    Check README.md for example script
    """
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()


if __name__ == "__main__":
    main()