OpenDAS / Megatron-LM · Commits

Commit 662dc982, authored Apr 15, 2020 by Neel Kant
Debug hash dump
Parent: 81c71789
Showing 3 changed files with 15 additions and 5 deletions:

  ict_qualitative_test.py        +13  -3
  megatron/data/ict_dataset.py    +1  -1
  pretrain_bert_ict.py            +1  -1
ict_qualitative_test.py

@@ -26,13 +26,14 @@ def main():
     data_iter = iter(get_dataloader(dataset))
     hash_data = defaultdict(list)
-    hash_matrix = np.random.rand(128, 1024)
+    hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
     all_input_tokens = []
     all_input_logits = []
     all_block_tokens = []
     all_block_logits = []
+    i = 0
     while True:
         try:
             input_tokens, input_types, input_pad_mask, \
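The first change swaps the plain NumPy projection matrix for a half-precision CUDA tensor. A likely reason, though the commit message doesn't say: the next hunk multiplies hash_matrix against block_logits with torch.matmul, and both operands must share a device and dtype, so a CPU float64 NumPy array would fail against fp16 GPU logits. A minimal sketch of the projection step under that assumption (the 128/1024 shapes come from the diff; the batch size and logits are made up):

import numpy as np
import torch

# Random projection directions: 128-d block embeddings -> 1024 hash scores.
# torch.cuda.HalfTensor copies the NumPy array to the GPU as fp16, matching
# logits produced under mixed-precision training. (Requires a CUDA device.)
hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))

# Hypothetical stand-in for one batch of block embeddings from the model.
block_logits = torch.randn(8, 128, device='cuda', dtype=torch.half)

block_hash_pos = torch.matmul(block_logits, hash_matrix)  # shape (8, 1024)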
@@ -43,8 +44,8 @@ def main():
                 input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
             block_hash_pos = torch.matmul(block_logits, hash_matrix)
-            block_hash_full = torch.concat((block_hash_pos, -block_hash_pos), axis=1)
-            block_hashes = torch.argmax(block_hash_full, axis=1)
+            block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
+            block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
             for hash, idx in zip(block_hashes, block_indices):
                 hash_data[int(hash)].append(int(idx))
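Two fixes here: torch.concat was not a valid name in the PyTorch of this era (the function is torch.cat), and the argmax result is moved to the CPU as a NumPy array before int() is applied element-wise in the loop. The cat((s, -s)) + argmax idiom itself is an LSH-style bucketing: each block is assigned the projection direction whose signed score has the largest magnitude, with the sign encoded by which half of the concatenated vector wins, giving 2 * 1024 possible buckets. A tiny self-contained illustration (dim= used in place of the diff's axis= alias; values made up):

import torch

P = 4                                        # 4 projection directions -> 8 signed buckets
s = torch.tensor([[0.2, -0.9, 0.5, 0.1]])    # scores for one block
full = torch.cat((s, -s), dim=1)             # shape (1, 2 * P)
bucket = torch.argmax(full, dim=1)           # tensor([5])

# bucket 5 = P + 1: direction 1 on its negative side, since s[1] = -0.9
# has the largest magnitude of any score.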
@@ -53,6 +54,15 @@ def main():
             all_block_tokens.append(block_tokens.detach().cpu().numpy())
             all_block_logits.append(block_logits.detach().cpu().numpy())
+            if i % 100 == 0:
+                print(i, flush=True)
+                print(len(all_block_tokens), flush=True)
+                print(block_tokens.shape, flush=True)
+            i += 1
+            if i == 10:
+                break
     all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
     all_input_logits = np.array(all_input_logits).reshape(-1, 128)
     all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
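After the loop, the per-batch arrays are stacked and flattened. np.array() on a list of equally shaped (batch, seq_length) arrays yields shape (num_batches, batch, seq_length), and reshape(-1, seq_length) merges the first two axes into one row per sequence. A quick sketch with made-up sizes (seq_length stands in for args.seq_length):

import numpy as np

seq_length = 16                                    # stand-in for args.seq_length
batches = [np.zeros((4, seq_length), dtype=np.int64) for _ in range(3)]
flat = np.array(batches).reshape(-1, seq_length)   # shape (12, 16): 3 batches x 4 rows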
megatron/data/ict_dataset.py

@@ -79,7 +79,7 @@ class InverseClozeDataset(Dataset):
             'context_text': np.array(context_tokens),
             'context_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask),
-            'context_indices': np.array([block_idx])
+            'context_indices': np.array([block_idx]).astype(np.int64)
         }
         return sample
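The explicit astype pins down the dtype: np.array([block_idx]) infers a platform-dependent integer type (int32 on Windows, int64 on most Linux builds), while the get_batch code in pretrain_bert_ict.py below treats every field as torch.int64. Casting at the source makes the sample deterministic. For illustration (the block_idx value is made up):

import numpy as np

block_idx = 42                                    # hypothetical block index
loose = np.array([block_idx])                     # dtype depends on the platform
fixed = np.array([block_idx]).astype(np.int64)    # always int64, matching torch.int64
print(loose.dtype, fixed.dtype)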
pretrain_bert_ict.py

@@ -47,7 +47,7 @@ def get_batch(data_iterator):
     # Items and their type.
     keys = ['input_text', 'input_types', 'input_pad_mask',
-            'context_text', 'context_types', 'context_pad_mask']
+            'context_text', 'context_types', 'context_pad_mask', 'context_indices']
     datatype = torch.int64

     # Broadcast data.
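Adding 'context_indices' to keys makes the new field travel with the batch: in Megatron's usual get_batch pattern, this list and datatype are handed to a broadcast helper that expects every listed field to exist in each sample with the shared integer dtype, which is why the int64 cast above matters. A self-contained stand-in for that packing step, with a hypothetical pack_batch helper and made-up sample values:

import torch

def pack_batch(sample, keys, datatype):
    # Hypothetical local stand-in for the broadcast step: every listed
    # field is pulled from the sample and coerced to the shared dtype.
    return {k: torch.as_tensor(sample[k]).to(datatype) for k in keys}

keys = ['input_text', 'input_types', 'input_pad_mask',
        'context_text', 'context_types', 'context_pad_mask', 'context_indices']
datatype = torch.int64

sample = {
    'input_text': [101, 7592, 102], 'input_types': [0, 0, 0],
    'input_pad_mask': [1, 1, 1], 'context_text': [101, 2088, 102],
    'context_types': [0, 0, 0], 'context_pad_mask': [1, 1, 1],
    'context_indices': [42],   # present thanks to this commit
}
batch = pack_batch(sample, keys, datatype)
assert batch['context_indices'].dtype == torch.int64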