OpenDAS / Megatron-LM

Commit 81c71789, authored Apr 15, 2020 by Neel Kant
Parent: 8ba76558

Implement reformer hashing scheme

Showing 4 changed files with 46 additions and 19 deletions.
ict_qualitative_test.py         +30 -10
megatron/data/helpers.cpp       +11 -6
megatron/data/ict_dataset.py    +3  -2
pretrain_bert_ict.py            +2  -1
ict_qualitative_test.py

+from collections import defaultdict
+import pickle

 import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
...
@@ -22,20 +25,33 @@ def main():
     dataset = get_dataset()
     data_iter = iter(get_dataloader(dataset))
+    hash_data = defaultdict(list)
+    hash_matrix = np.random.rand(128, 1024)

     all_input_tokens = []
     all_input_logits = []
     all_block_tokens = []
     all_block_logits = []
-    for i in range(100):
-        input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
-        input_logits, doc_logits, _ = model.module.module.forward(
+    while True:
+        try:
+            input_tokens, input_types, input_pad_mask, \
+                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
+        except StopIteration:
+            break
+        input_logits, block_logits, _ = model.module.module.forward(
             input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
+        block_hash_pos = torch.matmul(block_logits, hash_matrix)
+        block_hash_full = torch.concat((block_hash_pos, -block_hash_pos), axis=1)
+        block_hashes = torch.argmax(block_hash_full, axis=1)
+        for hash, idx in zip(block_hashes, block_indices):
+            hash_data[int(hash)].append(int(idx))

         all_input_tokens.append(input_tokens.detach().cpu().numpy())
         all_input_logits.append(input_logits.detach().cpu().numpy())
         all_block_tokens.append(block_tokens.detach().cpu().numpy())
-        all_block_logits.append(doc_logits.detach().cpu().numpy())
+        all_block_logits.append(block_logits.detach().cpu().numpy())

     all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
     all_input_logits = np.array(all_input_logits).reshape(-1, 128)
...
@@ -44,7 +60,14 @@ def main():
     np.save('input_tokens.npy', all_input_tokens)
     np.save('input_logits.npy', all_input_logits)
     np.save('block_tokens.npy', all_block_tokens)
-    np.save('doc_logits.npy', all_block_logits)
+    np.save('block_logits.npy', all_block_logits)
+
+    for hash, block_indices in hash_data.items():
+        hash_data[hash] = np.array(block_indices)
+    hash_data['matrix'] = hash_matrix
+    with open('hash_data.pkl', 'wb') as hash_file:
+        pickle.dump(hash_data, hash_file)


 def load_checkpoint():
...
@@ -78,16 +101,13 @@ def get_dataset():
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

     doc_idx_ptr = block_dataset.get_doc_idx()
     total_num_documents = block_dataset.doc_idx.shape[0] - 1
     block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])

     kwargs = dict(
         name='full',
         context_dataset=block_dataset,
         titles_dataset=titles_dataset,
         data_prefix=args.data_path,
-        num_epochs=None,
-        max_num_samples=total_num_documents * 3,
+        num_epochs=1,
+        max_num_samples=None,
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1
...
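The bucketing that ict_qualitative_test.py now performs is a Reformer-style angular LSH: each block's 128-dimensional logits are projected against a fixed random matrix, the positive and negated projections are concatenated, and the argmax picks one of 2 x 1024 signed buckets. Below is a minimal standalone sketch of that step, for illustration only; the function name hash_blocks is hypothetical, and it uses torch.cat/dim rather than the torch.concat/axis spelling in the diff.

import torch

def hash_blocks(block_logits, hash_matrix):
    # Project each block's logits onto the random hyperplanes: (batch, 1024).
    block_hash_pos = torch.matmul(block_logits, hash_matrix)
    # Concatenate positive and negated projections so argmax selects a signed
    # direction, i.e. one of 2 * 1024 buckets per block: (batch, 2048).
    block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)
    # Bucket id per block, in [0, 2048): shape (batch,).
    return torch.argmax(block_hash_full, dim=1)

# Usage: bucket a toy batch of 8 "block logits" with a fixed random projection.
hash_matrix = torch.rand(128, 1024)
block_logits = torch.randn(8, 128)
buckets = hash_blocks(block_logits, hash_matrix)

Note that the sketch keeps the projection matrix as a torch tensor, since torch.matmul expects both operands to be tensors; the diff builds it with np.random.rand.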
megatron/data/helpers.cpp

...
@@ -363,6 +363,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;
+    int32_t block_id = 0;

     // For each epoch:
     for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
...
@@ -425,14 +426,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                     // Populate the map.
                     if (second) {
-                        const auto map_index_0 = 3 * map_index;
+                        const auto map_index_0 = 4 * map_index;
                         maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                         maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                         maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+                        maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                     }

                     // Update indices / counters.
                     ++map_index;
+                    ++block_id;
                     prev_start_index = sent_index + 1;
                     seq_len = 0;
                     num_sent = 0;
...
@@ -440,6 +443,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                 } // for (auto sent_index=sent_index_first; ...
             } // if (num_remain_sent > 1) {
         } // for (int doc=0; doc < num_docs; ++doc) {
+        block_id = 0;
     } // for (int epoch=0; epoch < num_epochs; ++epoch) {

     if (!second) {
...
@@ -449,7 +453,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
         }
         assert(maps == NULL);
         assert(num_samples < 0);
-        maps = new DocIdx[3*map_index];
+        maps = new DocIdx[4*map_index];
         num_samples = static_cast<int64_t>(map_index);
     }
...
@@ -461,12 +465,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
-        const auto i0 = 3 * i;
-        const auto j0 = 3 * j;
+        const auto i0 = 4 * i;
+        const auto j0 = 4 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);
         swap(maps[i0 + 2], maps[j0 + 2]);
+        swap(maps[i0 + 3], maps[j0 + 3]);
     }

     // Method to deallocate memory.
...
@@ -477,8 +482,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Return the numpy array.
     const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3}, // shape
-                     {3*byte_size, byte_size}, // C-style contiguous strides
+    return py::array(std::vector<int64_t>{num_samples, 4}, // shape
+                     {4*byte_size, byte_size}, // C-style contiguous strides
                      maps, // the data pointer
                      free_when_done); // numpy array references
...
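The helpers.cpp changes widen each row of the samples mapping from 3 to 4 entries, (start_sentence_idx, end_sentence_idx, doc_idx, block_id), with block_id reset at the start of every epoch, and the shuffle now swaps whole 4-entry rows. A small numpy equivalent of that row-wise Fisher-Yates pass, offered only as an illustration: the helper name and the use of numpy's Generator instead of std::mt19937_64 are assumptions, so the resulting permutation will differ from the C++ code.

import numpy as np

def shuffle_block_mapping(maps_flat, num_samples, seed):
    # View the flat buffer as (num_samples, 4) rows of
    # (start_sentence_idx, end_sentence_idx, doc_idx, block_id).
    rows = maps_flat.reshape(num_samples, 4)
    rng = np.random.default_rng(seed + 1)
    # Fisher-Yates over whole rows, mirroring the C++ swap loop above.
    for i in range(num_samples - 1, 0, -1):
        j = int(rng.integers(0, i + 1))
        rows[[i, j]] = rows[[j, i]]  # swap all four entries of rows i and j
    return rows

The fourth column of each row is what InverseClozeDataset.__getitem__ below unpacks as block_idx.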
megatron/data/ict_dataset.py

...
@@ -47,7 +47,7 @@ class InverseClozeDataset(Dataset):
         return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx = self.samples_mapping[idx]
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         title = list(self.titles_dataset[int(doc_idx)])
         context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(context) > 1
...
@@ -78,7 +78,8 @@ class InverseClozeDataset(Dataset):
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
             'context_types': np.array(context_token_types),
-            'context_pad_mask': np.array(context_pad_mask)
+            'context_pad_mask': np.array(context_pad_mask),
+            'context_indices': np.array([block_idx])
         }

         return sample
...
pretrain_bert_ict.py

...
@@ -64,9 +64,10 @@ def get_batch(data_iterator):
     context_tokens = data_b['context_text'].long()
     context_types = data_b['context_types'].long()
     context_pad_mask = data_b['context_pad_mask'].long()
+    context_indices = data_b['context_indices'].long()

     return input_tokens, input_types, input_pad_mask, \
-           context_tokens, context_types, context_pad_mask
+           context_tokens, context_types, context_pad_mask, context_indices


 def forward_step(data_iterator, model):
...
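With these pieces in place, ict_qualitative_test.py writes hash_data.pkl, which maps each bucket id to the array of block indices that landed in it and stores the random projection under the 'matrix' key. A short sketch of how one might inspect that output; this is illustrative only and assumes the file produced by the test above.

import pickle

# Load the bucket assignments dumped by ict_qualitative_test.py and report
# how evenly the blocks spread across the hash buckets.
with open('hash_data.pkl', 'rb') as hash_file:
    hash_data = pickle.load(hash_file)

hash_matrix = hash_data.pop('matrix')  # the (128, 1024) random projection
bucket_sizes = {bucket: len(indices) for bucket, indices in hash_data.items()}
print(len(bucket_sizes), 'non-empty buckets;',
      'largest bucket holds', max(bucket_sizes.values()), 'blocks')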