OpenDAS / Megatron-LM · Commits · 81c71789
"vscode:/vscode.git/clone" did not exist on "1e651ca2c9f12bdcc5d63da8830847706e186f22"
Commit 81c71789, authored Apr 15, 2020 by Neel Kant

Implement reformer hashing scheme

Parent: 8ba76558
Showing 4 changed files with 46 additions and 19 deletions (+46 -19):

    ict_qualitative_test.py         +30 -10
    megatron/data/helpers.cpp       +11 -6
    megatron/data/ict_dataset.py    +3 -2
    pretrain_bert_ict.py            +2 -1
ict_qualitative_test.py  (view file @ 81c71789)
+from collections import defaultdict
+import pickle
 import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
...
@@ -22,20 +25,33 @@ def main():
     dataset = get_dataset()
     data_iter = iter(get_dataloader(dataset))
+    hash_data = defaultdict(list)
+    hash_matrix = np.random.rand(128, 1024)

     all_input_tokens = []
     all_input_logits = []
     all_block_tokens = []
     all_block_logits = []
-    for i in range(100):
-        input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
-        input_logits, doc_logits, _ = model.module.module.forward(input_tokens, input_types, input_pad_mask, \
-            block_tokens, block_token_types, block_pad_mask, return_logits=True)
+    while True:
+        try:
+            input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
+        except StopIteration:
+            break
+        input_logits, block_logits, _ = model.module.module.forward(input_tokens, input_types, input_pad_mask, \
+            block_tokens, block_pad_mask, block_token_types, return_logits=True)
+
+        block_hash_pos = torch.matmul(block_logits, hash_matrix)
+        block_hash_full = torch.concat((block_hash_pos, -block_hash_pos), axis=1)
+        block_hashes = torch.argmax(block_hash_full, axis=1)
+        for hash, idx in zip(block_hashes, block_indices):
+            hash_data[int(hash)].append(int(idx))

         all_input_tokens.append(input_tokens.detach().cpu().numpy())
         all_input_logits.append(input_logits.detach().cpu().numpy())
         all_block_tokens.append(block_tokens.detach().cpu().numpy())
-        all_block_logits.append(doc_logits.detach().cpu().numpy())
+        all_block_logits.append(block_logits.detach().cpu().numpy())

     all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
     all_input_logits = np.array(all_input_logits).reshape(-1, 128)
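For reference, the bucketing added above follows the angular LSH scheme from the Reformer paper: project each block embedding with a random matrix, concatenate the projection with its negation, and take the argmax as the bucket id. A minimal, self-contained sketch of that idea (illustrative only: it uses a torch random matrix and torch.cat rather than the numpy matrix and torch.concat used in the script, and all names are made up for the example):

    import torch

    def lsh_bucket(block_logits, hash_matrix):
        # block_logits: (batch, 128) block embeddings; hash_matrix: (128, n_buckets // 2)
        pos = torch.matmul(block_logits, hash_matrix)      # random projection
        full = torch.cat((pos, -pos), dim=1)               # scores for bucket and "anti-bucket" directions
        return torch.argmax(full, dim=1)                   # bucket id per block

    # usage: 2048 buckets (2 * 1024) from 128-dim block embeddings
    hash_matrix = torch.rand(128, 1024)
    block_logits = torch.randn(4, 128)
    buckets = lsh_bucket(block_logits, hash_matrix)        # shape (4,), values in [0, 2048)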
...
@@ -44,7 +60,14 @@ def main():
     np.save('input_tokens.npy', all_input_tokens)
     np.save('input_logits.npy', all_input_logits)
     np.save('block_tokens.npy', all_block_tokens)
-    np.save('doc_logits.npy', all_block_logits)
+    np.save('block_logits.npy', all_block_logits)
+
+    for hash, block_indices in hash_data.items():
+        hash_data[hash] = np.array(block_indices)
+    hash_data['matrix'] = hash_matrix
+    with open('hash_data.pkl', 'wb') as hash_file:
+        pickle.dump(hash_data, hash_file)


 def load_checkpoint():
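The pickled hash_data therefore maps each bucket id to the array of block indices that landed in that bucket, plus the projection matrix under the 'matrix' key. A small sketch of how the saved file could be inspected afterwards (assuming it was written by the loop above; the summary printed is just an example):

    import pickle

    with open('hash_data.pkl', 'rb') as f:
        hash_data = pickle.load(f)

    hash_matrix = hash_data.pop('matrix')       # the (128, 1024) projection used for hashing
    bucket_sizes = {h: len(idxs) for h, idxs in hash_data.items()}
    print(len(bucket_sizes), 'non-empty buckets;',
          'largest bucket holds', max(bucket_sizes.values()), 'blocks')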
...
@@ -78,16 +101,13 @@ def get_dataset():
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

-    doc_idx_ptr = block_dataset.get_doc_idx()
-    total_num_documents = block_dataset.doc_idx.shape[0] - 1
-    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
     kwargs = dict(
         name='full',
         context_dataset=block_dataset,
         titles_dataset=titles_dataset,
         data_prefix=args.data_path,
-        num_epochs=None,
-        max_num_samples=total_num_documents * 3,
+        num_epochs=1,
+        max_num_samples=None,
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1
...
megatron/data/helpers.cpp  (view file @ 81c71789)
...
@@ -363,6 +363,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;
+    int32_t block_id = 0;

     // For each epoch:
     for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
...
@@ -425,14 +426,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                     // Populate the map.
                     if (second) {
-                        const auto map_index_0 = 3 * map_index;
+                        const auto map_index_0 = 4 * map_index;
                         maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                         maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                         maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+                        maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                     }

                     // Update indices / counters.
                     ++map_index;
+                    ++block_id;
                     prev_start_index = sent_index + 1;
                     seq_len = 0;
                     num_sent = 0;
...
@@ -440,6 +443,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
            } // for (auto sent_index=sent_index_first; ...
        } // if (num_remain_sent > 1) {
    } // for (int doc=0; doc < num_docs; ++doc) {
+   block_id = 0;
 } // for (int epoch=0; epoch < num_epochs; ++epoch) {

 if (!second) {
...
@@ -449,7 +453,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     }
     assert(maps == NULL);
     assert(num_samples < 0);
-    maps = new DocIdx[3*map_index];
+    maps = new DocIdx[4*map_index];
     num_samples = static_cast<int64_t>(map_index);
 }
...
@@ -461,12 +465,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
-        const auto i0 = 3 * i;
-        const auto j0 = 3 * j;
+        const auto i0 = 4 * i;
+        const auto j0 = 4 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);
         swap(maps[i0 + 2], maps[j0 + 2]);
+        swap(maps[i0 + 3], maps[j0 + 3]);
     }

     // Method to deallocate memory.
...
@@ -477,8 +482,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Return the numpy array.
     const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3},  // shape
-                     {3*byte_size, byte_size},  // C-style contiguous strides
+    return py::array(std::vector<int64_t>{num_samples, 4},  // shape
+                     {4*byte_size, byte_size},  // C-style contiguous strides
                      maps,  // the data pointer
                      free_when_done);  // numpy array references
...
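In Python terms, each row of the samples mapping returned by build_blocks_mapping_impl now carries four values instead of three: (start_sentence_index, end_sentence_index, doc_index, block_id), where block_id is a running counter over blocks that is reset when an epoch's document loop finishes, so that (assuming the same iteration order) a given block receives the same id in every epoch. A rough sketch of that layout (illustrative only; the real construction and in-place shuffling happen in the C++ above, and the block ranges below are made up):

    import numpy as np

    # hypothetical per-document sentence ranges chosen as blocks in one epoch
    blocks_per_epoch = [
        (0, 3, 0),   # sentences [0, 3) of doc 0
        (3, 5, 0),   # sentences [3, 5) of doc 0
        (5, 9, 1),   # sentences [5, 9) of doc 1
    ]

    rows = []
    for epoch in range(2):
        block_id = 0                          # reset per epoch, as in the C++ change
        for start, end, doc in blocks_per_epoch:
            rows.append((start, end, doc, block_id))
            block_id += 1

    samples_mapping = np.array(rows, dtype=np.int64)   # shape (num_samples, 4)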
megatron/data/ict_dataset.py  (view file @ 81c71789)
...
@@ -47,7 +47,7 @@ class InverseClozeDataset(Dataset):
         return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx = self.samples_mapping[idx]
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         title = list(self.titles_dataset[int(doc_idx)])
         context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(context) > 1
...
@@ -78,7 +78,8 @@ class InverseClozeDataset(Dataset):
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
             'context_types': np.array(context_token_types),
-            'context_pad_mask': np.array(context_pad_mask)
+            'context_pad_mask': np.array(context_pad_mask),
+            'context_indices': np.array([block_idx])
         }

         return sample
...
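Each sample now carries the block id of its context under 'context_indices' as a length-1 array, so a default DataLoader collate stacks the field into a (batch_size, 1) tensor alongside the other keys, which is what get_batch then casts with .long(). A minimal sketch of that flow (hypothetical toy samples, default collate; only the two keys needed for the illustration are included):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    samples = [
        {'context_text': np.zeros(288, dtype=np.int64),
         'context_indices': np.array([block_idx])}
        for block_idx in (17, 42, 7)
    ]
    batch = next(iter(DataLoader(samples, batch_size=3)))
    print(batch['context_indices'].shape)   # torch.Size([3, 1]); .long() in get_batch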
pretrain_bert_ict.py  (view file @ 81c71789)
...
@@ -64,9 +64,10 @@ def get_batch(data_iterator):
     context_tokens = data_b['context_text'].long()
     context_types = data_b['context_types'].long()
     context_pad_mask = data_b['context_pad_mask'].long()
+    context_indices = data_b['context_indices'].long()

     return input_tokens, input_types, input_pad_mask, \
-           context_tokens, context_types, context_pad_mask
+           context_tokens, context_types, context_pad_mask, context_indices


 def forward_step(data_iterator, model):
...
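Note that any existing caller still unpacking six values from get_batch will now raise a ValueError, so consumers have to be updated together with this change; the qualitative test above already unpacks the new seventh value. A tiny illustration of the contract change (a toy tuple stands in for the real return value):

    batch = ('input_tokens', 'input_types', 'input_pad_mask',
             'context_tokens', 'context_types', 'context_pad_mask', 'context_indices')

    a, b, c, d, e, f, g = batch          # new seven-value contract: fine
    try:
        a, b, c, d, e, f = batch         # old six-value unpacking now fails
    except ValueError as err:
        print(err)                       # "too many values to unpack (expected 6)"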