OpenDAS / Megatron-LM / Commits / 32bb4edc

Commit 32bb4edc, authored Jun 05, 2020 by Neel Kant

    Prune changes to only be related to ICT

Parent: 674814a5
Changes: 27
Showing 20 changed files with 45 additions and 1675 deletions
faiss_test.py                                        +0   -192
ict_eval_bm25.py                                     +0   -124
indexer.py                                           +0   -293
megatron/arguments.py                                +0   -10
megatron/checkpointing.py                            +4   -5
megatron/data/bert_dataset.py                        +8   -9
megatron/data/dataset_utils.py                       +2   -10
megatron/data/ict_dataset.py                         +0   -179
megatron/data/preprocess_data.py                     +0   -125
megatron/data/realm_dataset.py                       +6   -80
megatron/data/realm_dataset_utils.py                 +1   -109
megatron/data/realm_index.py                         +0   -319
megatron/deprecated_data_utils/__init__.py           +3   -9
megatron/deprecated_data_utils/configure_data.py     +4   -2
megatron/deprecated_data_utils/datasets.py           +7   -184
megatron/global_vars.py                              +2   -2
megatron/initialize.py                               +2   -9
megatron/model/bert_model.py                         +1   -6
megatron/model/distributed.py                        +1   -1
megatron/model/language_model.py                     +4   -7
faiss_test.py  (deleted, 100644 → 0)
from collections import defaultdict
import time
import pickle

import faiss
from faiss import index_factory, index_cpu_to_gpu
import numpy as np

from megatron import get_args

PCAS = ['PCA', 'PCAR', 'PCAW', 'PCAWR']
# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution
# update: Using realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
# CONCLUSION: PCA should not be used for MIPS

QUANTIZERS = [
    'IVF4096_SQ16',
    # 'IMI2x9',
    'HNSW32_SQ16',
    # 'IVF4096_HNSW32'
]
# IMI2x9 or any other MultiIndex doesn't support inner product so it's unusable
# IVF4096_HNSW32 doesn't support inner product either

ENCODINGS = [
    'Flat',
    'PQ16np',  # PQ16, PQ16x12(np)
    'SQ4',
    'SQ8',
    'SQ6',
    'SQfp16',
    # 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]
# PQ16 is pretty slow for creating and adding - ~96s for 1e5, 105s for 1e6
# PQ16np is a bit faster but is pretty inaccurate - misses top-1 result 2/3 of time (1e6 embeds)
# PQ16x12(np) gets real slow. Uses 4096 centroids.
# SQfp16 is solid.
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)


def latest(times):
    return times[-1] - times[-2]


def get_embed_mean_and_cov():
    embed_data = pickle.load(open('/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl', 'rb'))
    embed_mean = embed_data['embed_mean']
    whitener = embed_data['embed_whitener']
    embed_cov = whitener.dot(whitener.transpose())
    return embed_mean, embed_cov


def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
    embeds = np.random.multivariate_normal(mean, cov, num_embeds).astype('float32')
    queries = np.random.multivariate_normal(mean, cov, num_queries).astype('float32')
    return embeds, queries


def get_random_embeds_and_queries(d, num_embeds, num_queries):
    embeds = np.random.rand(num_embeds, d).astype('float32')
    queries = np.random.rand(num_queries, d).astype('float32')
    return embeds, queries


def print_timing_stats(name, create_and_add, search):
    print('{:20s} Create and add embeds: {:10.4f}s | Search embeds: {:10.4f}s'.format(
        name, create_and_add, search))


def print_accuracy_stats(name, gold_indices, estimated_indices):
    gold_indices, estimated_indices = list(gold_indices), list(estimated_indices)
    results = defaultdict(int)
    for gold, estimated in zip(gold_indices, estimated_indices):
        if gold[0] not in estimated:
            results['first_missing'] += 1
        elif np.array_equal(gold, estimated):
            results['all_equal'] += 1
        else:
            results['mixed'] += 1

    result_strs = ['first_missing', 'all_equal', 'mixed']
    print('{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'.format(
        name, *[results[s] for s in result_strs]))


def create_and_test_gold(d, k, embeds, queries):
    times = [time.time()]
    res = faiss.StandardGpuResources()
    gold_idx = index_cpu_to_gpu(res, 0, index_factory(d, 'Flat'))
    gold_idx.add(embeds)
    times.append(time.time())
    create_and_add = latest(times)

    distances, indices = gold_idx.search(queries, k)
    times.append(time.time())
    print_timing_stats('Flat', create_and_add, latest(times))
    print('-' * 100)
    return distances, indices


def test_pca(d, k, embeds, queries, pca_dim):
    distances, indices = create_and_test_gold(d, k, embeds, queries)
    times = [time.time()]

    all_pca_indices = []
    for s in PCAS:
        pca_idx = index_factory(d, s + "{},Flat".format(pca_dim), faiss.METRIC_INNER_PRODUCT)
        pca_idx.train(embeds)
        pca_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        pca_distances, pca_indices = pca_idx.search(queries, k)
        all_pca_indices.append(pca_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    for s, pca_indices in zip(PCAS, all_pca_indices):
        print_accuracy_stats(s, indices, pca_indices)


def test_quantizers(d, k, embeds, queries):
    distances, indices = create_and_test_gold(d, k, embeds, queries)
    times = [time.time()]

    for s in QUANTIZERS:
        if 'HNSW' in s:
            quant_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
        else:
            quant_idx = index_factory(d, "Flat," + s, faiss.METRIC_INNER_PRODUCT)
        quant_idx.train(embeds)
        quant_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        quant_distances, quant_indices = quant_idx.search(queries, k)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))


def test_encodings(d, k, embeds, queries):
    distances, indices = create_and_test_gold(d, k, embeds, queries)
    times = [time.time()]

    all_encode_indices = []
    for s in ENCODINGS:
        encode_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
        encode_idx.train(embeds)
        encode_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        _, encode_indices = encode_idx.search(queries, k)
        all_encode_indices.append(encode_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    for s, encode_indices in zip(ENCODINGS, all_encode_indices):
        print_accuracy_stats(s, indices, encode_indices)


def run_all_tests():
    mean, cov = get_embed_mean_and_cov()
    embeds, queries = get_embeds_and_queries(mean, cov, int(1e6), 256)
    d = 128
    k = 10
    test_pca(d, k, embeds, queries, 96)
    test_quantizers(d, k, embeds, queries)
    test_encodings(d, k, embeds, queries)


if __name__ == "__main__":
    run_all_tests()
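Note (not part of the commit): the comments in faiss_test.py record which FAISS factory strings held up for maximum inner product search. For reference, a minimal self-contained sketch of the same gold-versus-approximate comparison on small random CPU data; the sizes and the SQfp16 choice are illustrative, not the values used above.

import faiss
import numpy as np

d, n, q, k = 128, 10000, 32, 10
embeds = np.random.rand(n, d).astype('float32')
queries = np.random.rand(q, d).astype('float32')

# exact maximum inner product search (the "Flat" gold index)
gold = faiss.index_factory(d, 'Flat', faiss.METRIC_INNER_PRODUCT)
gold.add(embeds)
_, gold_ids = gold.search(queries, k)

# scalar-quantized variant; train() is required before add() for most non-Flat indexes
approx = faiss.index_factory(d, 'SQfp16', faiss.METRIC_INNER_PRODUCT)
approx.train(embeds)
approx.add(embeds)
_, approx_ids = approx.search(queries, k)

# fraction of queries whose true top-1 neighbor survives in the approximate top-k
recall_at_k = np.mean([gold_ids[i, 0] in approx_ids[i] for i in range(q)])
print('top-1 recall@{}: {:.3f}'.format(k, recall_at_k))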
ict_eval_bm25.py  (deleted, 100644 → 0)
import lucene
import sys

from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search.similarities import BM25Similarity
from org.apache.lucene.util import Version

import torch
import torch.distributed as dist

from indexer import get_ict_dataset, get_one_epoch_dataloader
from megatron.initialize import initialize_megatron
from pretrain_bert_ict import get_batch


def setup():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])


def run(embed_all=False):
    dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
    dataloader = iter(get_one_epoch_dataloader(dset))

    index_dir = SimpleFSDirectory(Paths.get("full_wiki_index/"))
    analyzer = StandardAnalyzer()
    analyzer.setMaxTokenLength(1024)
    config = IndexWriterConfig(analyzer)
    config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
    writer = IndexWriter(index_dir, config)

    # field for document ID
    t1 = FieldType()
    t1.setStored(True)
    t1.setTokenized(False)

    # field for document text
    t2 = FieldType()
    t2.setStored(True)
    t2.setTokenized(True)
    t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

    correct = total = 0
    round_correct = torch.zeros(1).cuda()
    round_total = torch.zeros(1).cuda()

    for round in range(100000):
        with torch.no_grad():
            try:
                query_tokens, query_pad_mask, \
                block_tokens, block_pad_mask, block_index_data = get_batch(dataloader)
            except:
                break

        # query_tokens = query_tokens.detach().cpu().numpy()
        block_tokens = block_tokens.detach().cpu().numpy()

        # query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
        block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]

        def add_document(text, writer, doc_id):
            doc = Document()
            doc.add(Field("text", text, t2))
            doc.add(Field("doc_id", doc_id, t1))
            writer.addDocument(doc)

        # add documents to index writer
        for i in range(len(block_strs)):
            add_document(block_strs[i], writer, i)

        # write and finalize the index
        writer.commit()

        # define BM25 searcher
        # searcher = IndexSearcher(DirectoryReader.open(index_dir))
        # searcher.setSimilarity(BM25Similarity())

        # # feed queries and get scores for everything in the index
        # hits_list = []
        # for s in query_strs:
        #     query = QueryParser("text", analyzer).parse(s)
        #     hits = searcher.search(query, 1).scoreDocs
        #     hits_list.append(hits)

        # for (i, hits) in enumerate(hits_list):
        #     doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
        #     correct += int(i in doc_ids)
        #     total += 1

        # dist.all_reduce(round_correct)
        # dist.all_reduce(round_total)

        # correct += int(round_correct.item())
        # total += int(round_total.item())

        # round_correct -= round_correct
        # round_total -= round_total

        # print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
        if round % 10 == 0:
            print(round)

    writer.close()


# Plan
# overall accuracy test:
# have index with all blocks. For BERT these are token ids, for BM25 these are tokens
#
# 1. run batch size 4096 BM25 self similarity test. For this I can just detokenize out of the dataset.
#    I get the retrieval scores in the forward_step and log the results.
# 2. Create a BM25 index over all of wikipedia, have it ready for use in megatron QA.
#
# Create an index with the block embeddings with block ids

if __name__ == "__main__":
    setup()
    run()
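Note (not part of the commit): the commented-out section above is the intended evaluation loop, which scores each pseudo-query against the BM25 index of blocks and counts how often the query's own block is the top hit. Below is a small dependency-free sketch of that self-retrieval check, with a hand-rolled Okapi BM25 standing in for the Lucene searcher and a made-up toy corpus.

import math
from collections import Counter

def bm25_scores(query_terms, docs, k1=1.2, b=0.75):
    """Score every document in `docs` (lists of tokens) against `query_terms`."""
    N = len(docs)
    avgdl = sum(len(d) for d in docs) / N
    df = Counter(t for d in docs for t in set(d))
    scores = []
    for d in docs:
        tf = Counter(d)
        s = 0.0
        for t in query_terms:
            if t not in tf:
                continue
            idf = math.log(1 + (N - df[t] + 0.5) / (df[t] + 0.5))
            s += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(s)
    return scores

# toy "blocks"; in the deleted script these would be detokenized ICT blocks
blocks = [
    "the inverse cloze task pairs a sentence with its surrounding block".split(),
    "bm25 ranks documents by term frequency and inverse document frequency".split(),
    "faiss builds dense indexes for maximum inner product search".split(),
]
query = "rank documents with term frequency and inverse document frequency".split()

scores = bm25_scores(query, blocks)
top_hit = max(range(len(blocks)), key=scores.__getitem__)
print("top hit:", top_hit, "correct:", top_hit == 1)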
indexer.py  (deleted, 100644 → 0)
import os
import sys
import time

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.global_vars import set_global_variables
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination
from pretrain_bert_ict import get_batch, model_provider

INDEX_READY = None


def pprint(*args):
    print(*args, flush=True)


def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
                                      ignore_unknown_args=False, allow_no_cuda=False):
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    setup_realm_groups_and_vars()
    global INDEX_READY
    INDEX_READY = get_index_ready()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')
    torch.distributed.barrier()

    if args.rank < args.max_training_rank:
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All trainers ready.")
        return
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All indexers ready.")
        runner.run_async()


def setup_realm_groups_and_vars():
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now
    set_model_parallel_group(dist.new_group([args.rank]))
    init_realm_groups(max_training_rank, world_size)
    if args.rank < max_training_rank:
        set_data_parallel_group(get_train_group())
    else:
        set_data_parallel_group(get_index_group())


class IndexBuilder(object):
    def __init__(self):
        args = get_args()
        self.debug = args.debug
        self.rank = args.rank
        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()
        self.is_main_builder = args.rank == 0

    def load_attributes(self):
        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def build_and_save_index(self):
        i = 1
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except:
                    break

                block_index_data = detach(block_index_data)
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
                self.block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                i += 1
                if i % 1000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                if self.debug:
                    break

        self.block_data.save_shard(self.rank)
        torch.distributed.barrier(get_data_parallel_group())
        del self.model

        if self.is_main_builder:
            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
        self.block_data.clear()


class AsyncIndexBuilder(IndexBuilder):
    def __init__(self, rank):
        self.rank = rank
        args = get_args()
        self.is_main_builder = self.rank == args.max_training_rank
        self.main_builder_idx = args.max_training_rank
        self.debug = args.debug

        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()

        global INDEX_READY
        INDEX_READY = get_index_ready()

    def run_async(self):
        global INDEX_READY
        # synchronize for start
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        while True:
            print("Starting (again!)", flush=True)
            self.build_and_save_index()
            self.send_index_ready_signal()

            while INDEX_READY == 1:
                print("Waiting for new model checkpoint.", flush=True)
                time.sleep(5)
            self.load_attributes()

    def load_attributes(self):
        try:
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except:
            print(">>>>> No realm chkpt available", flush=True)
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def send_index_ready_signal(self):
        global INDEX_READY
        if self.is_main_builder:
            INDEX_READY = 1 - INDEX_READY
            print("Switched INDEX_READY", flush=True)
        torch.cuda.synchronize()

        # send handle
        dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
        # recv handle
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        torch.distributed.barrier(get_data_parallel_group())


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    if no_grad:
        with torch.no_grad():
            model.load_state_dict(ict_state_dict)
    else:
        model.load_state_dict(ict_state_dict)
    torch.distributed.barrier(get_data_parallel_group())

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset, batch_size=None):
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
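Note (not part of the commit): get_one_epoch_dataloader above relies on Megatron's DistributedBatchSampler to shard one ordered pass over the dataset across data-parallel ranks. A rough single-process analogue with stock PyTorch samplers, showing the same "exactly one ordered epoch, fixed batch size, drop the ragged tail" behaviour minus the rank sharding; the toy dataset is made up for illustration.

import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, BatchSampler

class ToyBlockDataset(Dataset):
    """Stands in for the ICT dataset; yields dummy token tensors."""
    def __init__(self, num_samples=10, seq_length=8):
        self.data = torch.arange(num_samples * seq_length).reshape(num_samples, seq_length)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return {'block_tokens': self.data[idx]}

dataset = ToyBlockDataset()
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)

for batch in loader:
    print(batch['block_tokens'].shape)  # torch.Size([4, 8]) for each full batch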
megatron/arguments.py
@@ -195,7 +195,6 @@ def _add_training_args(parser):
                        'by this value.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--max-training-rank', type=int, default=None)

     return parser

@@ -343,14 +342,6 @@ def _add_data_args(parser):
                        help='Path to combined dataset to split.')
     group.add_argument('--titles-data-path', type=str, default=None,
                        help='Path to titles dataset used for ICT')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Path to pickled BlockData data structure')
-    group.add_argument('--block-index-path', type=str, default=None,
-                       help='Path to pickled data structure for efficient block indexing')
-    group.add_argument('--block-top-k', type=int, default=5,
-                       help='Number of blocks to use as top-k during retrieval')
-    group.add_argument('--async-indexer', action='store_true',
-                       help='Whether the indexer job is running asynchronously with a trainer job')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '

@@ -388,7 +379,6 @@ def _add_data_args(parser):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--faiss-use-gpu', action='store_true')

     return parser
megatron/checkpointing.py
@@ -24,7 +24,6 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import mpu
-from megatron.mpu.initialize import get_train_group, get_data_parallel_group
 from megatron import get_args
 from megatron import print_rank_0

@@ -45,7 +44,7 @@ def check_checkpoint_args(checkpoint_args):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    # _compare('max_position_embeddings')
+    _compare('max_position_embeddings')
     _compare('make_vocab_size_divisible_by')
     _compare('padded_vocab_size')
     _compare('tokenizer_type')

@@ -119,14 +118,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
             print('  successfully saved {}'.format(checkpoint_name))

     # Wait so everyone is done (necessary)
-    torch.distributed.barrier(get_data_parallel_group())
+    torch.distributed.barrier()
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))
     # Wait so everyone is done (not necessary)
-    torch.distributed.barrier(get_data_parallel_group())
+    torch.distributed.barrier()


 def load_checkpoint(model, optimizer, lr_scheduler):

@@ -243,7 +242,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                   'exiting ...'.format(checkpoint_name))
         sys.exit()

-    # torch.distributed.barrier()
+    torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
         print('  successfully loaded {}'.format(checkpoint_name))
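Note (not part of the commit): the checkpointing hunks above swap the data-parallel-group barrier that the REALM branch introduced back to a plain global barrier. The two call forms differ only in which ranks must arrive before anyone proceeds; a minimal single-process sketch using the gloo backend and made-up rendezvous settings.

import os
import torch.distributed as dist

# single-process rendezvous just to make the two barrier forms callable
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

dist.barrier()                      # global: every rank in the job must arrive

subgroup = dist.new_group(ranks=[0])
dist.barrier(group=subgroup)        # scoped: only ranks in `subgroup` must arrive

dist.destroy_process_group()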
megatron/data/bert_dataset.py
@@ -25,6 +25,7 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+from megatron.data.dataset_utils import build_training_sample
 from megatron import print_rank_0

@@ -61,8 +62,6 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad

-        from megatron.data.dataset_utils import build_training_sample
-        self.build_sample_fn = build_training_sample

     def __len__(self):
         return self.samples_mapping.shape[0]

@@ -73,13 +72,13 @@ class BertDataset(Dataset):
         # Note that this rng state should be numpy and not python since
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
-        return self.build_sample_fn(sample, seq_length,
+        return build_training_sample(sample, seq_length,
                                      self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
                                      self.mask_id, self.pad_id,
                                      self.masked_lm_prob, np_rng)


 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
megatron/data/dataset_utils.py
@@ -25,7 +25,7 @@ import numpy as np
 from megatron import print_rank_0, get_args
 from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset

-DATASET_TYPES = ['standard_bert', 'ict', 'realm']
+DATASET_TYPES = ['standard_bert', 'ict']


 def compile_helper():
     """Compile helper function ar runtime. Make sure this

@@ -388,7 +388,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
     assert len(tokentypes) == num_tokens
-    assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))
+    assert len(masked_positions) == len(masked_labels)

     # Tokens and token types.
     filler = [pad_id] * padding_length

@@ -456,7 +456,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     def build_dataset(index, name):
         from megatron.data.realm_dataset import ICTDataset
-        from megatron.data.realm_dataset import REALMDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.

@@ -486,13 +485,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 query_in_block_prob=args.query_in_block_prob,
                 **kwargs
             )
-        elif dataset_type == 'realm':
-            dataset = REALMDataset(
-                block_dataset=indexed_dataset,
-                title_dataset=title_dataset,
-                masked_lm_prob=masked_lm_prob,
-                **kwargs
-            )
         else:
             dataset = BertDataset(
                 indexed_dataset=indexed_dataset,
megatron/data/ict_dataset.py  (deleted, 100644 → 0)
import itertools
import random
import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import mpu
from megatron.data import helpers


class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        title = list(self.title_dataset[int(doc_idx)])
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence to be the query.
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

        # keep the query in the context 10% of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample

    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = [self.cls_id] + tokens + [self.sep_id]
        if title is not None:
            # tokens += title + [self.sep_id]
            tokens = tokens + title + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask

    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

        return samples_mapping
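Note (not part of the commit): a tiny pure-Python sketch, for illustration only, of the sampling rule in __getitem__ above: one sentence of a block becomes the pseudo-query, and with probability query_in_block_prob it is also left inside the evidence block. The token ids and block below are made up.

import random

def make_ict_example(block_sentences, query_in_block_prob=0.1, rng=random.Random(0)):
    block = [list(s) for s in block_sentences]      # copy; sentences are token-id lists
    rand_sent_idx = rng.randint(0, len(block) - 1)  # the pruned ICTDataset samples any sentence
    if rng.random() < query_in_block_prob:
        query = block[rand_sent_idx].copy()         # query stays in the block
    else:
        query = block.pop(rand_sent_idx)            # query removed from the block
    context = [tok for sent in block for tok in sent]
    return query, context

query, context = make_ict_example([[101, 7592], [2023, 2003], [1037, 3231]])
print(query, context)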
megatron/data/preprocess_data.py  (deleted, 100644 → 0)
import argparse
import itertools
import json
import multiprocessing
import nltk
import sys
import time

import torch

sys.path.insert(0, '../')
sys.path.insert(0, '../../')
from tokenizer.bert_tokenization import FullTokenizer
from data.indexed_dataset import make_builder


class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""


class Encoder(object):
    splitter = None
    tokenizer = None

    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
        spliter = nltk.load("tokenizers/punkt/english.pickle")
        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=spliter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = spliter

    def encode(self, json_line):
        text = json.loads(json_line)[self.args.json_key]
        if not text:
            text = "no text"
        doc_ids = []
        for sentence in Encoder.splitter.tokenize(text):
            tokens = Encoder.tokenizer.tokenize(sentence)
            ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
            if len(ids) > 0:
                doc_ids.append(ids)
            else:
                print("no ids!", flush=True)
                tokens = Encoder.tokenizer.tokenize("no text")
                ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
                doc_ids.append(ids)
        if self.args.flatten and len(doc_ids) > 1:
            doc_ids = [list(itertools.chain(*doc_ids))]
        return doc_ids, len(json_line)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, help='Path to input JSON')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--flatten', action='store_true', help='Path to input JSON')
    parser.add_argument('--json-key', type=str, default='text',
                        help='Key to extract from json')
    parser.add_argument('--output-prefix', type=str,
                        help='Path to binary output file without suffix')
    parser.add_argument('--workers', type=int, default=20,
                        help='Number of worker processes to launch')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
    parser.add_argument('--keep-newlines', action='store_true',
                        help='Keep newlines between sentences.')
    parser.add_argument('--dataset-impl', type=str, default='mmap',
                        choices=['lazy', 'cached', 'mmap'])
    args = parser.parse_args()
    args.keep_empty = False

    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)

    print(f"Vocab size: {tokenizer.vocab_size()}")
    output_bin_file = "{}.bin".format(args.output_prefix)
    output_idx_file = "{}.idx".format(args.output_prefix)
    builder = make_builder(output_bin_file, impl=args.dataset_impl,
                           vocab_size=tokenizer.vocab_size())

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for sentence in doc:
            #print(sentence)
            #print(tokenizer.convert_ids_to_tokens(sentence))
            builder.add_item(torch.IntTensor(sentence))
        builder.end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            print(f"Processed {i} documents",
                  f"({i / elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    builder.finalize(output_idx_file)


if __name__ == '__main__':
    main()
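Note (not part of the commit): CustomLanguageVars above changes punkt's period-context regex so that trailing whitespace stays attached to the sentence, which is what keeps newlines from being eaten. A minimal sketch of the effect, assuming the nltk punkt model is available for download.

import nltk

nltk.download("punkt", quiet=True)
default_splitter = nltk.load("tokenizers/punkt/english.pickle")

class KeepNewlines(nltk.tokenize.punkt.PunktLanguageVars):
    _period_context_fmt = r"""
        \S*                      # some word material
        %(SentEndChars)s         # a potential sentence ending
        \s*                      # keep trailing whitespace with the sentence
        (?=(?P<after_tok>
            %(NonWord)s          # either other punctuation
            |
            (?P<next_tok>\S+)    # or the next token, with no \s+ consumed
        ))"""

keep_newline_splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
    train_text=default_splitter._params, lang_vars=KeepNewlines())

text = "First sentence.\nSecond sentence."
print(default_splitter.tokenize(text))       # the newline between sentences is dropped
print(keep_newline_splitter.tokenize(text))  # "First sentence.\n" keeps its newline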
megatron/data/realm_dataset.py
@@ -5,64 +5,6 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list
-
-
-class REALMDataset(Dataset):
-    """Dataset containing simple masked sentences for masked language modeling.
-
-    The dataset should yield sentences just like the regular BertDataset
-    However, this dataset also needs to be able to return a set of blocks
-    given their start and end indices.
-
-    Presumably
-    """
-    def __init__(self, name, block_dataset, title_dataset, data_prefix,
-                 num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
-        self.name = name
-        self.seed = seed
-        self.max_seq_length = max_seq_length
-        self.masked_lm_prob = masked_lm_prob
-        self.block_dataset = block_dataset
-        self.title_dataset = title_dataset
-        self.short_seq_prob = short_seq_prob
-        self.rng = random.Random(self.seed)
-        self.samples_mapping = get_block_samples_mapping(
-            block_dataset, title_dataset, data_prefix, num_epochs,
-            max_num_samples, max_seq_length, seed, name)
-
-        self.tokenizer = get_tokenizer()
-        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
-        self.cls_id = self.tokenizer.cls
-        self.sep_id = self.tokenizer.sep
-        self.mask_id = self.tokenizer.mask
-        self.pad_id = self.tokenizer.pad
-
-    def __len__(self):
-        return self.samples_mapping.shape[0]
-
-    def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
-        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
-        assert len(block) > 1
-
-        np_rng = np.random.RandomState(seed=(self.seed + idx))
-
-        sample = build_realm_training_sample(block,
-                                             self.max_seq_length,
-                                             self.vocab_id_list,
-                                             self.vocab_id_to_token_list,
-                                             self.cls_id,
-                                             self.sep_id,
-                                             self.mask_id,
-                                             self.pad_id,
-                                             self.masked_lm_prob,
-                                             np_rng)
-        sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
-        return sample
-
-
 class ICTDataset(Dataset):

@@ -95,6 +37,7 @@ class ICTDataset(Dataset):
         return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
+        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         if self.use_titles:
             title = list(self.title_dataset[int(doc_idx)])

@@ -107,7 +50,7 @@ class ICTDataset(Dataset):
         rand_sent_idx = self.rng.randint(0, len(block) - 1)

-        # keep the query in the context 10% of the time.
+        # keep the query in the context query_in_block_prob fraction of the time.
         if self.rng.random() < self.query_in_block_prob:
             query = block[rand_sent_idx].copy()
         else:

@@ -134,30 +77,12 @@ class ICTDataset(Dataset):
     def encode_text(self, text):
         return self.tokenizer.tokenize(text)

-    def decode_tokens(self, token_ids, hardcore=False):
+    def decode_tokens(self, token_ids):
         """Utility function to help with debugging mostly"""
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
         exclude_list = ['[PAD]', '[CLS]']
-        if hardcore:
-            extra_exclude = ['[SEP]']
-            exclude_list.extend(extra_exclude)
         non_pads = [t for t in tokens if t not in exclude_list]

         joined_strs = join_str_list(non_pads)
-        if hardcore:
-            escape_chars = ['+', '-', '&', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
-            skip_me = False
-            joined_strs = list(joined_strs)
-            joined_strs = [s for s in joined_strs if s != '\\']
-            for i, c in enumerate(joined_strs):
-                if skip_me:
-                    skip_me = False
-                    continue
-                if c in escape_chars:
-                    joined_strs.insert(i, '\\')
-                    skip_me = True
-            joined_strs = ''.join(joined_strs)
-            if len(joined_strs) < 3:
-                joined_strs += 'text here'
         return joined_strs

     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""

@@ -170,13 +95,14 @@ class ICTDataset(Dataset):
         return (block_tokens, block_pad_mask)

     def get_null_block(self):
+        """Get empty block and title - used in REALM pretraining"""
         block, title = [], []
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
         return (block_tokens, block_pad_mask)

     def concat_and_pad_tokens(self, tokens, title=None):
-        """concat with special tokens and pad sequence to self.max_seq_length"""
+        """Concat with special tokens and pad sequence to self.max_seq_length"""
         if title is None:
             tokens = [self.cls_id] + tokens + [self.sep_id]
         else:
megatron/data/realm_dataset_utils.py
-import itertools
 import os
-import random
 import time

 import numpy as np
-import spacy
 import torch

-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import print_rank_0, mpu
+from megatron import get_tokenizer, print_rank_0, mpu
-
-SPACY_NER = spacy.load('en_core_web_lg')
-
-
-def build_realm_training_sample(sample, max_seq_length,
-                                vocab_id_list, vocab_id_to_token_dict,
-                                cls_id, sep_id, mask_id, pad_id,
-                                masked_lm_prob, np_rng):
-    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
-    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
-
-    try:
-        masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
-    except TypeError:
-        # this means the above returned None, and None isn't iterable.
-        # TODO: consider coding style.
-        max_predictions_per_seq = masked_lm_prob * max_seq_length
-        masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
-            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
-            cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
-
-    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
-        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
-                                   masked_labels, pad_id, max_seq_length)
-
-    train_sample = {
-        'tokens': tokens_np,
-        'labels': labels_np,
-        'loss_mask': loss_mask_np,
-        'pad_mask': padding_mask_np
-    }
-    return train_sample
-
-
-def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
-    tokens = []
-    tokens.append(cls_id)
-    tokens.extend(list(_tokens))
-    tokens.append(sep_id)
-    tokentypes = [0] * len(tokens)
-    return tokens, tokentypes
-
-
 def join_str_list(str_list):

@@ -63,69 +18,6 @@ def join_str_list(str_list):
     return result

-
-def id_to_str_pos_map(token_ids, tokenizer):
-    """Given a list of ids, return a list of integers which correspond to the starting index
-    of the corresponding token in the original string (with spaces, without artifacts e.g. ##)"""
-    token_strs = tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-    pos_map = [0]
-    for i in range(len(token_strs) - 1):
-        len_prev = len(token_strs[i])
-
-        # do not add the length of the "##"
-        if token_strs[i].startswith("##"):
-            len_prev -= 2
-
-        # add the length of the space if needed
-        if token_strs[i + 1].startswith("##"):
-            pos_map.append(pos_map[-1] + len_prev)
-        else:
-            pos_map.append(pos_map[-1] + len_prev + 1)
-
-    # make sure total size is correct
-    offset = -2 if token_strs[-1].startswith("##") else 0
-    total_len = pos_map[-1] + len(token_strs[-1]) + offset
-    assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))
-
-    return pos_map
-
-
-def salient_span_mask(tokens, mask_id):
-    """Creates the predictions for the masked LM objective.
-    Note: Tokens here are vocab ids and not text tokens."""
-    tokenizer = get_tokenizer()
-    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
-
-    # need to get all named entities
-    entities = SPACY_NER(tokens_str).ents
-    entities = [e for e in entities if e.text != "CLS"]
-    if len(entities) == 0:
-        return None
-
-    entity_idx = np.random.randint(0, len(entities))
-    selected_entity = entities[entity_idx]
-    token_pos_map = id_to_str_pos_map(tokens, tokenizer)
-
-    mask_start = mask_end = 0
-    set_mask_start = False
-    while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
-        if token_pos_map[mask_start] > selected_entity.start_char:
-            set_mask_start = True
-        if not set_mask_start:
-            mask_start += 1
-        mask_end += 1
-
-    masked_positions = list(range(mask_start - 1, mask_end))
-    labels = []
-    output_tokens = tokens.copy()
-    for id_idx in masked_positions:
-        labels.append(tokens[id_idx])
-        output_tokens[id_idx] = mask_id
-
-    #print("-" * 100 + '\n',
-    #      "TOKEN STR\n", tokens_str + '\n',
-    #      "SELECTED ENTITY\n", selected_entity.text + '\n',
-    #      "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
-    return output_tokens, masked_positions, labels
-
-
 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                               max_num_samples, max_seq_length, seed, name):
     if not num_epochs:
megatron/data/realm_index.py  (deleted, 100644 → 0)
from
collections
import
defaultdict
import
os
import
pickle
import
shutil
import
faiss
import
numpy
as
np
import
torch
from
megatron
import
get_args
,
mpu
def
detach
(
tensor
):
return
tensor
.
detach
().
cpu
().
numpy
()
class
BlockData
(
object
):
def
__init__
(
self
):
self
.
embed_data
=
dict
()
self
.
meta_data
=
dict
()
self
.
temp_dir_name
=
'temp_block_data'
def
state
(
self
):
return
{
'embed_data'
:
self
.
embed_data
,
'meta_data'
:
self
.
meta_data
}
def
clear
(
self
):
"""Clear the data structures to save memory"""
self
.
embed_data
=
dict
()
self
.
meta_data
=
dict
()
@
classmethod
def
load_from_file
(
cls
,
fname
):
print
(
"
\n
> Unpickling block data"
,
flush
=
True
)
state_dict
=
pickle
.
load
(
open
(
fname
,
'rb'
))
print
(
">> Finished unpickling block data
\n
"
,
flush
=
True
)
new_index
=
cls
()
new_index
.
embed_data
=
state_dict
[
'embed_data'
]
new_index
.
meta_data
=
state_dict
[
'meta_data'
]
return
new_index
def
add_block_data
(
self
,
block_indices
,
block_embeds
,
block_metas
,
allow_overwrite
=
False
):
for
idx
,
embed
,
meta
in
zip
(
block_indices
,
block_embeds
,
block_metas
):
if
not
allow_overwrite
and
idx
in
self
.
embed_data
:
raise
ValueError
(
"Unexpectedly tried to overwrite block data"
)
self
.
embed_data
[
idx
]
=
np
.
float16
(
embed
)
self
.
meta_data
[
idx
]
=
meta
def
save_shard
(
self
,
rank
):
if
not
os
.
path
.
isdir
(
self
.
temp_dir_name
):
os
.
mkdir
(
self
.
temp_dir_name
)
# save the data for each shard
with
open
(
'{}/{}.pkl'
.
format
(
self
.
temp_dir_name
,
rank
),
'wb'
)
as
data_file
:
pickle
.
dump
(
self
.
state
(),
data_file
)
def
consolidate_shards_and_save
(
self
,
ignore_shard
=
0
):
"""Combine all the shards made using self.save_shard()"""
fnames
=
os
.
listdir
(
self
.
temp_dir_name
)
for
fname
in
fnames
:
with
open
(
'{}/{}'
.
format
(
self
.
temp_dir_name
,
fname
),
'rb'
)
as
f
:
data
=
pickle
.
load
(
f
)
old_size
=
len
(
self
.
embed_data
)
shard_size
=
len
(
data
[
'embed_data'
])
self
.
embed_data
.
update
(
data
[
'embed_data'
])
self
.
meta_data
.
update
(
data
[
'meta_data'
])
# assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)
args
=
get_args
()
with
open
(
args
.
block_data_path
,
'wb'
)
as
final_file
:
pickle
.
dump
(
self
.
state
(),
final_file
)
shutil
.
rmtree
(
self
.
temp_dir_name
,
ignore_errors
=
True
)
class FaissMIPSIndex(object):
    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu
        self.id_map = dict()

        # alsh
        self.m = 5
        self.u = 0.99
        self.max_norm = None

        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if not self.use_gpu:
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
        print(">> Finished building index", flush=True)

        if self.use_gpu:
            res = faiss.StandardGpuResources()
            # self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config.useFloat16 = True
            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)

    def reset_index(self):
        self._set_block_index()

    def add_block_embed_data(self, all_block_data, clear_block_data=False):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx
        if clear_block_data:
            all_block_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(np.float32(np.array(block_embeds)))
        else:
            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        if self.index_type == 'flat_l2':
            query_embeds = self.alsh_query_preprocess_fn(query_embeds)
        query_embeds = np.float32(detach(query_embeds))
        # query_embeds = query_embeds.float()

        with torch.no_grad():
            if reconstruct:
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape)
                    for i in range(block_indices.shape[0]):
                        for j in range(block_indices.shape[1]):
                            fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                return distances, block_indices

    # functions below are for ALSH, which currently isn't being used
    def get_norm_powers_and_halves_array(self, embeds):
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))

        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)
        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self, block_embeds):
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds

        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds

        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
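For the 'flat_ip' path above, the wrapped index is just an exhaustive inner-product search over the block embeddings. A rough CPU-only sketch of what add_block_embed_data and search_mips_index(reconstruct=False) amount to, on random data and assuming a local faiss install:

import faiss
import numpy as np

embed_size, num_blocks, top_k = 128, 10000, 5
block_ids = np.arange(1000, 1000 + num_blocks).astype(np.int64)   # stand-in block indices
block_embeds = np.float32(np.random.randn(num_blocks, embed_size))
queries = np.float32(np.random.randn(4, embed_size))

index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)                 # lets us search with our own block ids
index.add_with_ids(block_embeds, block_ids)

distances, indices = index.search(queries, top_k)   # [4 x top_k] inner products and block ids

# brute-force check: the top-1 id should match the argmax of the raw inner products
assert (indices[:, 0] == block_ids[np.argmax(queries @ block_embeds.T, axis=1)]).all()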
# This was the original hashing scheme, not used anymore
class RandProjectionLSHIndex(object):
    """Class for holding hashed data"""
    def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
        np.random.seed(seed)
        self.hash_data = defaultdict(list)
        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

    def state(self):
        state = {
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def save_to_file(self):
        args = get_args()
        with open(args.block_index_path, 'wb') as index_file:
            pickle.dump(self.state(), index_file)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")

        hash_matrix = state_dict['hash_matrix']
        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix
        return new_index

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def hash_embeds(self, embeds, write_block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
        embed_hashes = detach(torch.argmax(embed_scores, axis=1))

        if write_block_data is not None:
            for hash, indices in zip(embed_hashes, write_block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes

    def hash_whitened_block_embeds(self, block_data):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
        block_idx, all_embeds = zip(*block_data.embed_data.items())
        arr_embeds = np.transpose(np.array(all_embeds))
        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean

        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.hash_data = defaultdict(list)

        batch_size = 16384
        i = 0
        args = get_args()
        with torch.no_grad():
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
                if len(batch_meta) == 0:
                    break
                self.hash_embeds(batch_embed, batch_meta)
                i += 1

    def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*all_block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)

        with torch.no_grad():
            query_hashes = self.hash_embeds(query_embeds)

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))

            best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
            best_block_hashes = self.hash_embeds(best_blocks)

        print('Query hashes: ', query_hashes)
        print('Block hashes: ', best_block_hashes)

        equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
        # array of zeros and ones which can be used for counting success
        return equal_arr

    def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
        if self.whiten:
            if self.embed_mean is None:
                self.hash_whitened_block_embeds(all_block_data)
            embed_size = self.hash_matrix.shape[0]
            query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
        else:
            block_idx, all_embeds = zip(*all_block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))
            mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)
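The hashing in hash_embeds above is plain signed random projection: a column-normalized projection matrix gives num_buckets/2 scores, the scores and their negations are concatenated, and the argmax picks the bucket. A numpy-only sketch of the same idea with toy sizes and no whitening:

import numpy as np
from collections import defaultdict

rng = np.random.RandomState(0)
embed_size, num_buckets = 64, 32

# random projection directions, normalized column-wise as in RandProjectionLSHIndex
proj = 2 * rng.rand(embed_size, num_buckets // 2) - 1
proj /= np.linalg.norm(proj, axis=0, keepdims=True)

def hash_embeds(embeds):
    scores = embeds @ proj                                  # [n x num_buckets/2]
    scores = np.concatenate((scores, -scores), axis=1)      # [n x num_buckets]
    return np.argmax(scores, axis=1)                        # bucket id per embedding

blocks = rng.randn(1000, embed_size)
buckets = defaultdict(list)
for block_id, h in enumerate(hash_embeds(blocks)):
    buckets[h].append(block_id)

query = rng.randn(1, embed_size)
candidates = buckets[hash_embeds(query)[0]]                 # only blocks sharing the query's bucket

exact_mips_test above then just measures how often the true inner-product argmax lands in the query's bucket.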
megatron/deprecated_data_utils/__init__.py
View file @
32bb4edc
...
@@ -19,7 +19,7 @@ import math
 import torch

 from .samplers import DistributedBatchSampler
-from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset, InverseClozeDataset
+from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
 from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
 from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
 from . import corpora
...
@@ -126,10 +126,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         ds = split_ds(ds, split)
     if 'bert' in ds_type.lower():
         presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-        if 'ict' in ds_type.lower():
-            dstype = InverseClozeDataset
-        else:
-            dstype = bert_sentencepair_dataset
+        dstype = bert_sentencepair_dataset
         ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
               if d is not None else None for d in ds]
     elif ds_type.lower() == 'gpt2':
...
@@ -137,10 +134,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
     else:
         if 'bert' in ds_type.lower():
             presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-            if 'ict' in ds_type.lower():
-                dstype = InverseClozeDataset
-            else:
-                dstype = bert_sentencepair_dataset
+            dstype = bert_sentencepair_dataset
             ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
         elif ds_type.lower() == 'gpt2':
             ds = GPT2Dataset(ds, max_seq_len=seq_length)
...
megatron/deprecated_data_utils/configure_data.py
View file @
32bb4edc
...
@@ -46,9 +46,11 @@ class DataConfig:

 def make_data_loader(dataset, batch_size, args):

-    if args.shuffle:
+    shuffle = args.shuffle
+
+    if shuffle:
         sampler = data_utils.samplers.RandomSampler(dataset,
                                                     replacement=True,
                                                     num_samples=batch_size * args.train_iters)
     else:
         sampler = torch.utils.data.SequentialSampler(dataset)
     world_size = torch.distributed.get_world_size(
...
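The shuffle branch above picks between sampling with replacement for a fixed number of training iterations and a plain sequential pass. In stock PyTorch terms (a toy dataset and hypothetical batch size, using torch's built-in samplers rather than data_utils.samplers), that choice looks like:

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

dataset = TensorDataset(torch.arange(100).float())
batch_size, train_iters, shuffle = 8, 50, True

if shuffle:
    # with replacement, sized so the loader yields exactly train_iters batches
    sampler = RandomSampler(dataset, replacement=True, num_samples=batch_size * train_iters)
else:
    sampler = SequentialSampler(dataset)

loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)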
megatron/deprecated_data_utils/datasets.py
View file @
32bb4edc
...
@@ -18,7 +18,6 @@ import os
 import time
 from operator import itemgetter
 from bisect import bisect_right
-import itertools
 import json
 import csv
 import math
...
@@ -337,6 +336,7 @@ class json_dataset(data.Dataset):
         all_strs (list): list of all strings from the dataset
         all_labels (list): list of all labels from the dataset (if they have it)
     """
     def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
                  text_key='sentence', label_key='label', loose_json=False, **kwargs):
         self.is_lazy = False
...
@@ -354,6 +354,9 @@ class json_dataset(data.Dataset):
                 self.X.append(s)
                 self.Y.append(j[label_key])
+        if binarize_sent:
+            self.Y = binarize_labels(self.Y, hard=binarize_sent)

     def SetTokenizer(self, tokenizer):
         if tokenizer is None:
             self.using_tokenizer = False
...
@@ -642,8 +645,10 @@ class bert_sentencepair_dataset(data.Dataset):
         np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         # get seq length
         target_seq_length = self.max_seq_len
+        short_seq = False
         if rng.random() < self.short_seq_prob:
             target_seq_length = rng.randint(2, target_seq_length)
+            short_seq = True
         # get sentence pair and label
         is_random_next = None
...
@@ -817,7 +822,7 @@ class bert_sentencepair_dataset(data.Dataset):
     def mask_token(self, idx, tokens, types, vocab_words, rng):
         """
         helper function to mask `idx` token from `tokens` according to
-        section 3.1.1 of https://arxiv.org/pdf/1810.04805.pdf
+        section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
         """
         label = tokens[idx]
         if rng.random() < 0.8:
...
@@ -876,185 +881,3 @@ class bert_sentencepair_dataset(data.Dataset):
                 mask_labels[idx] = label
         return (output_tokens, output_types), mask, mask_labels, pad_mask


class InverseClozeDataset(data.Dataset):
    """
    Dataset containing sentences and various 'blocks' for an inverse cloze task.
    Arguments:
        ds (Dataset or array-like): data corpus to use for training
        max_seq_len (int): maximum sequence length to use for an input sentence
        short_seq_prob (float): Proportion of input sentences purposefully shorter than max_seq_len
        dataset_size (int): number of input sentences in the dataset.
    """
    def __init__(self, ds, max_seq_len=512, short_seq_prob=.01, dataset_size=None,
                 presplit_sentences=False, weighted=True, **kwargs):
        self.ds = ds
        self.ds_len = len(self.ds)
        self.tokenizer = self.ds.GetTokenizer()
        self.vocab_words = list(self.tokenizer.text_token_vocab.values())
        self.ds.SetTokenizer(None)
        self.max_seq_len = max_seq_len
        self.short_seq_prob = short_seq_prob
        self.dataset_size = dataset_size
        if self.dataset_size is None:
            # this is wrong
            self.dataset_size = self.ds_len * (self.ds_len - 1)
        self.presplit_sentences = presplit_sentences
        if not self.presplit_sentences:
            nltk.download('punkt', download_dir="./nltk")

        self.weighted = weighted
        if self.weighted:
            if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                lens = np.array(self.ds.lens)
            else:
                lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
            self.total_len = np.sum(lens)
            self.weighting = list(accumulate(lens))
        else:
            self.weighting = None

    def get_weighted_samples(self, np_rng):
        if self.weighting is not None:
            idx = np_rng.randint(self.total_len)
            return bisect_right(self.weighting, idx)
        else:
            return np_rng.randint(self.ds_len - 1)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # get rng state corresponding to index (allows deterministic random pair)
        rng = random.Random(idx + 1000)
        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])

        # get seq length. Save 2 tokens for beginning and end
        target_seq_length = self.max_seq_len - 2
        if rng.random() < self.short_seq_prob:
            target_seq_length = rng.randint(5, target_seq_length)

        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
        input_tokens, input_token_types, input_pad_mask = input_data
        context_tokens, context_token_types, context_pad_mask = context_data

        sample = {
            'input_text': np.array(input_tokens),
            'query_types': np.array(input_token_types),
            'input_pad_mask': np.array(input_pad_mask),
            'context_text': np.array(context_tokens),
            'block_types': np.array(context_token_types),
            'context_pad_mask': np.array(context_pad_mask)
        }
        return sample

    def get_sentence_split_doc(self, idx):
        """fetch document at index idx and split into sentences"""
        document = self.ds[idx]
        if isinstance(document, dict):
            document = document['text']
        lines = document.split('\n')
        if self.presplit_sentences:
            return [line for line in lines if line]
        rtn = []
        for line in lines:
            if line != '':
                rtn.extend(tokenize.sent_tokenize(line))
        return rtn

    def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
        """tokenize sentence and get token types"""
        tokens = self.tokenizer.EncodeAsIds(sent).tokenization
        str_type = 'str' + str(sentence_num)
        token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
        return tokens, token_types

    def get_input_and_context(self, target_seq_length, rng, np_rng):
        """fetches a sentence and its surrounding context"""
        num_tries = 0
        while num_tries < 20:
            num_tries += 1
            doc = None
            while doc is None:
                doc_idx = self.get_weighted_samples(np_rng)
                # doc is a list of sentences
                doc = self.get_sentence_split_doc(doc_idx)
                if not doc:
                    doc = None

            # set up and tokenize the entire selected document
            num_sentences = len(doc)
            padless_max_len = self.max_seq_len - 2

            # select a random sentence from the document as input
            # TODO: consider adding multiple input sentences.
            input_sentence_idx = rng.randint(0, num_sentences - 1)
            tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
            input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
            if not len(input_tokens) > 0:
                continue

            context_tokens, context_token_types = [], []
            # 10% of the time, the input sentence is left in the context.
            # The other 90% of the time, remove it.
            if rng.random() < 0.1:
                context_tokens = input_tokens.copy()
                context_token_types = input_token_types.copy()

            # parameters for examining sentences to add to the context
            view_preceding = True
            view_radius = 1

            while len(context_tokens) < padless_max_len:
                # keep adding sentences while the context can accommodate more.
                if view_preceding:
                    examine_idx = input_sentence_idx - view_radius
                    if examine_idx >= 0:
                        new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
                        context_tokens = new_tokens + context_tokens
                        context_token_types = new_token_types + context_token_types
                else:
                    examine_idx = input_sentence_idx + view_radius
                    if examine_idx < num_sentences:
                        new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
                        context_tokens += new_tokens
                        context_token_types += new_token_types
                    view_radius += 1
                view_preceding = not view_preceding
                if view_radius > num_sentences:
                    break

            # assemble the tokens and token types of the context
            context_tokens = context_tokens[:padless_max_len]
            context_token_types = context_token_types[:padless_max_len]
            if not len(context_tokens) > 0:
                continue

            # concatenate 'CLS' and 'SEP' tokens and add extra token types
            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
                input_tokens, input_token_types)
            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
                context_tokens, context_token_types)

            return (input_tokens, input_token_types, input_pad_mask), \
                   (context_tokens, context_token_types, context_pad_mask)
        else:
            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")

    def concat_and_pad_tokens(self, tokens, token_types):
        """concat with special tokens and pad sequence to self.max_seq_len"""
        tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
        token_types = [token_types[0]] + token_types + [token_types[0]]
        assert len(tokens) <= self.max_seq_len
        num_pad = max(0, self.max_seq_len - len(tokens))
        pad_mask = [0] * len(tokens) + [1] * num_pad
        tokens += [self.tokenizer.get_command('pad').Id] * num_pad
        token_types += [token_types[0]] * num_pad
        return tokens, token_types, pad_mask
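The removed InverseClozeDataset builds one ICT example per __getitem__: a randomly chosen sentence becomes the query, its neighbours (alternating before and after) become the context block, and 10% of the time the query sentence is left inside the block. A toy, tokenizer-free sketch of that pairing logic, using whitespace "tokens" and a hypothetical document:

import random

def make_ict_example(doc_sentences, max_block_tokens=30, keep_prob=0.1, rng=None):
    rng = rng or random.Random(0)

    # pick the query sentence
    query_idx = rng.randint(0, len(doc_sentences) - 1)
    query = doc_sentences[query_idx].split()

    # optionally keep the query inside its own block (10% of the time in the original)
    block = list(query) if rng.random() < keep_prob else []

    # grow the block outwards, alternating preceding and following sentences
    view_preceding, radius = True, 1
    while len(block) < max_block_tokens and radius <= len(doc_sentences):
        idx = query_idx - radius if view_preceding else query_idx + radius
        if 0 <= idx < len(doc_sentences):
            tokens = doc_sentences[idx].split()
            block = tokens + block if view_preceding else block + tokens
        if not view_preceding:
            radius += 1
        view_preceding = not view_preceding

    return query, block[:max_block_tokens]

doc = ["The aardvark is nocturnal.", "It eats ants and termites.",
       "It is native to Africa.", "Its teeth lack enamel."]
query, block = make_ict_example(doc)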
megatron/global_vars.py
View file @
32bb4edc
...
@@ -164,14 +164,14 @@ class _Timer:
     def start(self):
         """Start the timer."""
         assert not self.started_, 'timer has already been started'
-        # torch.cuda.synchronize()
+        torch.cuda.synchronize()
         self.start_time = time.time()
         self.started_ = True

     def stop(self):
         """Stop the timer."""
         assert self.started_, 'timer is not started'
-        # torch.cuda.synchronize()
+        torch.cuda.synchronize()
         self.elapsed_ += (time.time() - self.start_time)
         self.started_ = False
...
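Re-enabling torch.cuda.synchronize() in _Timer matters because CUDA kernels launch asynchronously; without a synchronize, time.time() mostly measures launch overhead rather than kernel time. A minimal illustration, assuming a CUDA device is available:

import time
import torch

x = torch.randn(4096, 4096, device='cuda')

# async timing: the matmul has only been queued when time.time() is read
start = time.time()
y = x @ x
async_ms = (time.time() - start) * 1000

# synchronized timing: wait for outstanding work, then for the kernel to finish
torch.cuda.synchronize()
start = time.time()
y = x @ x
torch.cuda.synchronize()
sync_ms = (time.time() - start) * 1000

print('async {:.3f} ms vs synchronized {:.3f} ms'.format(async_ms, sync_ms))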
megatron/initialize.py
View file @
32bb4edc
...
@@ -15,7 +15,6 @@
 """Megatron initialization."""

-import datetime
 import random
 import os
...
@@ -62,7 +61,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     _write_args_to_tensorboard()


-def init_distributed():
+def _initialize_distributed():
+    """Initialize torch.distributed and mpu."""
     args = get_args()
     device_count = torch.cuda.device_count()
...
@@ -102,13 +102,6 @@ def init_distributed():
                                             world_size=args.world_size, rank=args.rank,
                                             init_method=init_method)

-
-def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
-    init_distributed()
-
-    args = get_args()
-    device_count = torch.cuda.device_count()
     # Set the model-parallel / data-parallel communicators.
     if device_count > 0:
         mpu.initialize_model_parallel(args.model_parallel_size)
...
megatron/model/bert_model.py
View file @
32bb4edc
...
@@ -125,17 +125,12 @@ class BertModel(MegatronModule):
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)

-        max_pos_embeds = None
-        if not add_binary_head and ict_head_size is None:
-            max_pos_embeds = 2 * args.seq_length
-
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=add_pooler,
             init_method=init_method,
-            scaled_init_method=scaled_init_method,
-            max_pos_embeds=max_pos_embeds)
+            scaled_init_method=scaled_init_method)

         if not self.add_ict_head:
             self.lm_head = BertLMHead(
...
megatron/model/distributed.py
View file @
32bb4edc
...
@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
             if not no_scale and not reduce_after:
                 coalesced /= dist.get_world_size(group=self.data_parallel_group)
             dist.all_reduce(coalesced, group=self.data_parallel_group)
-            # torch.cuda.synchronize()
+            torch.cuda.synchronize()
             if not no_scale and reduce_after:
                 coalesced /= dist.get_world_size(group=self.data_parallel_group)
             for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
...
megatron/model/language_model.py
View file @
32bb4edc
...
@@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method, scaled_init_method, max_pos_embeds=None):
+                       init_method, scaled_init_method):
     """Build language model and return along with the key to save."""
     args = get_args()
...
@@ -60,8 +60,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
-        add_pooler=add_pooler,
-        max_pos_embeds=max_pos_embeds)
+        add_pooler=add_pooler)

     # key used for checkpoints.
     language_model_key = 'language_model'
...
@@ -268,8 +267,7 @@ class TransformerLanguageModel(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
-                 add_pooler=False,
-                 max_pos_embeds=None):
+                 add_pooler=False):
         super(TransformerLanguageModel, self).__init__()
         args = get_args()
...
@@ -278,11 +276,10 @@ class TransformerLanguageModel(MegatronModule):
         self.init_method = init_method
         self.add_pooler = add_pooler
-        max_pos_embeds = args.max_position_embeddings if max_pos_embeds is None else max_pos_embeds

         # Embeddings
         self.embedding = Embedding(self.hidden_size,
                                    args.padded_vocab_size,
-                                   max_pos_embeds,
+                                   args.max_position_embeddings,
                                    args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)
...
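The removed max_pos_embeds plumbing let BertModel size the position-embedding table independently of args.max_position_embeddings (growing it to 2 * args.seq_length in one configuration). The constraint it worked around is the usual embedding-table bound: position ids must stay below the table size. A rough sketch of that constraint with a bare nn.Embedding (illustrative sizes only):

import torch
import torch.nn as nn

max_position_embeddings, hidden_size = 512, 8
position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

seq_length = 512
position_ids = torch.arange(seq_length)          # 0 .. 511, fits the table
out = position_embeddings(position_ids)          # [512 x 8]

longer_ids = torch.arange(2 * seq_length)        # 0 .. 1023 would index past the table
# position_embeddings(longer_ids) would raise an index error;
# a larger table (e.g. 2 * seq_length rows, as the removed code allocated) is needed instead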