OpenDAS / Megatron-LM · Commits

Commit 4170fc69, authored Apr 20, 2020 by Neel Kant
Merge hashing into realm-mlm
Parents: ee2490d5, 43d5d84b

Showing 3 changed files with 174 additions and 6 deletions (+174 -6):
  hashed_index.py                +166  -0
  megatron/data/ict_dataset.py     +4  -3
  pretrain_bert_ict.py             +4  -3
hashed_index.py  (new file, mode 100644, +166 -0)

from collections import defaultdict
import pickle

import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider


def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    hash_data = defaultdict(list)
    hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
    #all_input_tokens = []
    #all_input_logits = []
    #all_block_tokens = []
    block_data = defaultdict(list)
    all_block_logits = []
    all_block_indices = []
    my_rank = args.rank
    block_file = open(f'block_data{my_rank}.pkl', 'wb')

    i = 0
    while True:
        try:
            input_tokens, input_types, input_pad_mask, \
                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
        except:
            break

        # TODO: make sure input is still in block
        input_logits, block_logits, _ = model.module.module.forward(
            input_tokens, input_types, input_pad_mask,
            block_tokens, block_pad_mask, block_token_types, return_logits=True)

        block_hash_pos = torch.matmul(block_logits, hash_matrix)
        block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
        block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
        for hash, indices_array in zip(block_hashes, block_indices):
            hash_data[int(hash)].append(indices_array)

        #all_input_tokens.append(input_tokens.detach().cpu().numpy())
        #all_input_logits.append(input_logits.detach().cpu().numpy())
        #all_block_tokens.append(block_tokens.detach().cpu().numpy())
        #all_block_logits.append(block_logits.detach().cpu().numpy())
        #all_block_indices.append(block_indices.detach().cpu().numpy()[:, 3])

        block_logits = block_logits.detach().cpu().numpy()
        block_indices = block_indices.detach().cpu().numpy()[:, 3]
        for logits, idx in zip(block_logits, block_indices):
            pickle.dump({idx: logits}, block_file)

        if i == 100:
            print(i)
        i += 1

    block_file.close()

    #all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    #all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    #all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
    #all_block_logits = np.array(all_block_logits).reshape(-1, 128)
    #all_block_indices = np.array(all_block_indices).reshape(all_block_logits.shape[0])
    #for logits, idx in zip(all_block_logits, all_block_indices):
    #    block_data[idx] = logits
    #with as block_file:
    #    pickle.dump(block_data, block_file)
    #np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
    #np.save(f'input_logits{my_rank}.npy', all_input_logits)
    #np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
    #np.save(f'block_logits{my_rank}.npy', all_block_logits)

    for hash, block_indices in hash_data.items():
        hash_data[hash] = np.array(block_indices)
    hash_data['matrix'] = hash_matrix
    with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
        pickle.dump(hash_data, hash_file)


def load_checkpoint():
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0
    checkpoint_name = get_checkpoint_name(args.load, iteration, False)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_dataset():
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    kwargs = dict(
        name='full',
        context_dataset=block_dataset,
        titles_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1)

    dataset = InverseClozeDataset(**kwargs)
    return dataset


def get_dataloader(dataset):
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
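The indexing loop in main() applies a sign-based random-projection hash: each 128-dimensional block embedding is multiplied by a fixed random (128, 1024) matrix, concatenated with its own negation, and assigned to the argmax bucket, for 2048 possible buckets. Below is a minimal sketch, not part of this commit, of how a downstream consumer could reuse the saved hash_data{rank}.pkl to fetch candidate blocks for a query embedding. The function name and default path are illustrative, and loading assumes torch with CUDA is available, since the matrix was pickled as a torch.cuda.HalfTensor.

import pickle

import numpy as np
import torch  # required to unpickle the stored hash matrix (a torch.cuda.HalfTensor)


def lookup_candidate_blocks(query_embed, hash_data_path='hash_data0.pkl'):
    """Return the stored block index rows whose blocks hashed to the same
    bucket as `query_embed` (a 128-dim vector). Illustrative only."""
    with open(hash_data_path, 'rb') as f:
        hash_data = pickle.load(f)

    # 'matrix' holds the (128, 1024) projection used at index time; every
    # other key maps a bucket id to the block index rows collected by main().
    hash_matrix = hash_data['matrix'].float().cpu().numpy()

    # Same bucketing rule as main(): project, mirror with the negation, argmax.
    proj = np.asarray(query_embed, dtype=np.float32) @ hash_matrix
    bucket = int(np.argmax(np.concatenate((proj, -proj))))

    # Each stored row is [start_idx, end_idx, doc_idx, block_idx].
    return hash_data.get(bucket, [])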
megatron/data/ict_dataset.py  (+4 -3)

@@ -41,6 +41,7 @@ class InverseClozeDataset(Dataset):
     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         title = list(self.title_dataset[int(doc_idx)])
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(block) > 1
@@ -50,8 +51,8 @@ class InverseClozeDataset(Dataset):
         else:
             rand_sent_idx = self.rng.randint(1, len(block) - 2)

-        # keep the query in the block 10% of the time.
-        if self.rng.random() < 0.1:
+        # keep the query in the context 10% of the time.
+        if self.rng.random() < 1:
             query = block[rand_sent_idx].copy()
         else:
             query = block.pop(rand_sent_idx)
@@ -71,7 +72,7 @@ class InverseClozeDataset(Dataset):
             'block_tokens': np.array(block_tokens),
             'block_types': np.array(block_token_types),
             'block_pad_mask': np.array(block_pad_mask),
-            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx])
+            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         }
         return sample
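The new .astype(np.int64) matters because get_batch in pretrain_bert_ict.py broadcasts every key as torch.int64 (see the next file), and each block_indices row records where a block lives in the indexed dataset. A small illustrative sketch, not part of this commit, of mapping such a row back to token ids, mirroring the lookup __getitem__ itself performs; the helper name and arguments are assumptions.

def tokens_from_block_indices(indices_row, block_dataset, title_dataset):
    """Recover the raw block and title token ids for one 'block_indices' row,
    the same way InverseClozeDataset.__getitem__ slices its datasets."""
    start_idx, end_idx, doc_idx, block_idx = (int(x) for x in indices_row)

    # One list of token ids per sentence in the block.
    block = [list(block_dataset[i]) for i in range(start_idx, end_idx)]
    title = list(title_dataset[doc_idx])
    return block, title, block_idx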
pretrain_bert_ict.py  (+4 -3)

@@ -46,7 +46,7 @@ def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_types', 'query_pad_mask',
-            'block_tokens', 'block_types', 'block_pad_mask']
+            'block_tokens', 'block_types', 'block_pad_mask', 'block_indices']
     datatype = torch.int64

     # Broadcast data.
@@ -63,9 +63,10 @@ def get_batch(data_iterator):
     block_tokens = data_b['block_tokens'].long()
     block_types = data_b['block_types'].long()
     block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_indices'].long()

     return query_tokens, query_types, query_pad_mask, \
-           block_tokens, block_types, block_pad_mask
+           block_tokens, block_types, block_pad_mask, block_indices


 def forward_step(data_iterator, model):
@@ -75,7 +76,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     query_tokens, query_types, query_pad_mask, \
-        block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
+        block_tokens, block_types, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
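With 'block_indices' added to keys, every caller of get_batch now unpacks seven tensors instead of six; block_indices arrives as an int64 tensor of shape (batch_size, 4), which is what the indexing loop in hashed_index.py consumes and what forward_step now unpacks (without using it further in the hunk shown). A tiny illustrative helper, not part of this commit, showing the updated contract; the function and its StopIteration handling are assumptions about how a consumer might drain the iterator.

from pretrain_bert_ict import get_batch


def collect_block_indices(data_iterator):
    """Drain an ICT data iterator and gather every block_indices row."""
    rows = []
    while True:
        try:
            (query_tokens, query_types, query_pad_mask,
             block_tokens, block_types, block_pad_mask,
             block_indices) = get_batch(data_iterator)
        except StopIteration:
            break
        # Each row is [start_idx, end_idx, doc_idx, block_idx].
        rows.extend(block_indices.detach().cpu().numpy())
    return rows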