Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
8573ab35
Commit
8573ab35
authored
May 24, 2020
by
Neel Kant
Browse files
Use Faiss GPU index and report retrieval utility
parent
e59496bf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
27 deletions
+68
-27
megatron/arguments.py
megatron/arguments.py
+1
-0
megatron/data/realm_index.py
megatron/data/realm_index.py
+38
-17
megatron/utils.py
megatron/utils.py
+6
-1
pretrain_realm.py
pretrain_realm.py
+23
-9
No files found.
megatron/arguments.py
View file @
8573ab35
...
@@ -388,6 +388,7 @@ def _add_data_args(parser):
...
@@ -388,6 +388,7 @@ def _add_data_args(parser):
help
=
'Mask loss for the end of document tokens.'
)
help
=
'Mask loss for the end of document tokens.'
)
group
.
add_argument
(
'--query-in-block-prob'
,
type
=
float
,
default
=
0.1
,
group
.
add_argument
(
'--query-in-block-prob'
,
type
=
float
,
default
=
0.1
,
help
=
'Probability of keeping query in block for ICT dataset'
)
help
=
'Probability of keeping query in block for ICT dataset'
)
group
.
add_argument
(
'--faiss-use-gpu'
,
action
=
'store_true'
)
return
parser
return
parser
...
...
megatron/data/realm_index.py
View file @
8573ab35
...
@@ -33,9 +33,9 @@ class BlockData(object):
...
@@ -33,9 +33,9 @@ class BlockData(object):
@
classmethod
@
classmethod
def
load_from_file
(
cls
,
fname
):
def
load_from_file
(
cls
,
fname
):
print
(
"
> Unpickling block data"
)
print
(
"
\n
> Unpickling block data"
,
flush
=
True
)
state_dict
=
pickle
.
load
(
open
(
fname
,
'rb'
))
state_dict
=
pickle
.
load
(
open
(
fname
,
'rb'
))
print
(
"
> Finished unpickling
"
)
print
(
"
>
> Finished unpickling
block data
\n
"
,
flush
=
True
)
new_index
=
cls
()
new_index
=
cls
()
new_index
.
embed_data
=
state_dict
[
'embed_data'
]
new_index
.
embed_data
=
state_dict
[
'embed_data'
]
...
@@ -69,7 +69,7 @@ class BlockData(object):
...
@@ -69,7 +69,7 @@ class BlockData(object):
shard_size
=
len
(
data
[
'embed_data'
])
shard_size
=
len
(
data
[
'embed_data'
])
self
.
embed_data
.
update
(
data
[
'embed_data'
])
self
.
embed_data
.
update
(
data
[
'embed_data'
])
self
.
meta_data
.
update
(
data
[
'meta_data'
])
self
.
meta_data
.
update
(
data
[
'meta_data'
])
assert
(
len
(
self
.
embed_data
)
==
old_size
+
shard_size
)
or
(
str
(
ignore_shard
)
in
fname
)
#
assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)
args
=
get_args
()
args
=
get_args
()
with
open
(
args
.
block_data_path
,
'wb'
)
as
final_file
:
with
open
(
args
.
block_data_path
,
'wb'
)
as
final_file
:
...
@@ -82,6 +82,7 @@ class FaissMIPSIndex(object):
...
@@ -82,6 +82,7 @@ class FaissMIPSIndex(object):
self
.
index_type
=
index_type
self
.
index_type
=
index_type
self
.
embed_size
=
embed_size
self
.
embed_size
=
embed_size
self
.
use_gpu
=
use_gpu
self
.
use_gpu
=
use_gpu
self
.
id_map
=
dict
()
# alsh
# alsh
self
.
m
=
5
self
.
m
=
5
...
@@ -95,12 +96,20 @@ class FaissMIPSIndex(object):
...
@@ -95,12 +96,20 @@ class FaissMIPSIndex(object):
if
self
.
index_type
not
in
INDEX_TYPES
:
if
self
.
index_type
not
in
INDEX_TYPES
:
raise
ValueError
(
"Invalid index type specified"
)
raise
ValueError
(
"Invalid index type specified"
)
index
=
faiss
.
index_factory
(
self
.
embed_size
,
'Flat'
,
faiss
.
METRIC_INNER_PRODUCT
)
print
(
"
\n
> Building index"
,
flush
=
True
)
self
.
block_mips_index
=
faiss
.
IndexIDMap
(
index
)
self
.
block_mips_index
=
faiss
.
index_factory
(
self
.
embed_size
,
'Flat'
,
faiss
.
METRIC_INNER_PRODUCT
)
if
not
self
.
use_gpu
:
self
.
block_mips_index
=
faiss
.
IndexIDMap
(
self
.
block_mips_index
)
print
(
">> Finished building index"
,
flush
=
True
)
if
self
.
use_gpu
:
if
self
.
use_gpu
:
res
=
faiss
.
StandardGpuResources
()
res
=
faiss
.
StandardGpuResources
()
device
=
mpu
.
get_data_parallel_rank
()
# self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
self
.
block_mips_index
=
faiss
.
index_cpu_to_gpu
(
res
,
device
,
self
.
block_mips_index
)
config
=
faiss
.
GpuIndexFlatConfig
()
config
.
device
=
torch
.
cuda
.
current_device
()
config
.
useFloat16
=
True
self
.
block_mips_index
=
faiss
.
GpuIndexFlat
(
res
,
self
.
block_mips_index
,
config
)
print
(
">>> Loaded Faiss index on GPU {}
\n
"
.
format
(
self
.
block_mips_index
.
getDevice
()),
flush
=
True
)
def
reset_index
(
self
):
def
reset_index
(
self
):
self
.
_set_block_index
()
self
.
_set_block_index
()
...
@@ -108,12 +117,16 @@ class FaissMIPSIndex(object):
...
@@ -108,12 +117,16 @@ class FaissMIPSIndex(object):
def
add_block_embed_data
(
self
,
all_block_data
,
clear_block_data
=
False
):
def
add_block_embed_data
(
self
,
all_block_data
,
clear_block_data
=
False
):
"""Add the embedding of each block to the underlying FAISS index"""
"""Add the embedding of each block to the underlying FAISS index"""
block_indices
,
block_embeds
=
zip
(
*
all_block_data
.
embed_data
.
items
())
block_indices
,
block_embeds
=
zip
(
*
all_block_data
.
embed_data
.
items
())
if
self
.
use_gpu
:
for
i
,
idx
in
enumerate
(
block_indices
):
self
.
id_map
[
i
]
=
idx
if
clear_block_data
:
if
clear_block_data
:
all_block_data
.
clear
()
all_block_data
.
clear
()
if
self
.
index_type
==
'flat_l2'
:
if
self
.
use_gpu
:
block_embeds
=
self
.
alsh_block_preprocess_fn
(
block_embeds
)
self
.
block_mips_index
.
add
(
np
.
float32
(
np
.
array
(
block_embeds
)))
self
.
block_mips_index
.
add_with_ids
(
np
.
float32
(
np
.
array
(
block_embeds
)),
np
.
array
(
block_indices
))
else
:
self
.
block_mips_index
.
add_with_ids
(
np
.
float32
(
np
.
array
(
block_embeds
)),
np
.
array
(
block_indices
))
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
"""Get the top-k blocks by the index distance metric.
"""Get the top-k blocks by the index distance metric.
...
@@ -123,14 +136,22 @@ class FaissMIPSIndex(object):
...
@@ -123,14 +136,22 @@ class FaissMIPSIndex(object):
"""
"""
if
self
.
index_type
==
'flat_l2'
:
if
self
.
index_type
==
'flat_l2'
:
query_embeds
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
query_embeds
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
query_embeds
=
np
.
float32
(
query_embeds
)
query_embeds
=
np
.
float32
(
detach
(
query_embeds
))
# query_embeds = query_embeds.float()
if
reconstruct
:
with
torch
.
no_grad
():
top_k_block_embeds
=
self
.
block_mips_index
.
search_and_reconstruct
(
query_embeds
,
top_k
)
if
reconstruct
:
return
top_k_block_embeds
top_k_block_embeds
=
self
.
block_mips_index
.
search_and_reconstruct
(
query_embeds
,
top_k
)
else
:
return
top_k_block_embeds
distances
,
block_indices
=
self
.
block_mips_index
.
search
(
query_embeds
,
top_k
)
else
:
return
distances
,
block_indices
distances
,
block_indices
=
self
.
block_mips_index
.
search
(
query_embeds
,
top_k
)
if
self
.
use_gpu
:
fresh_indices
=
np
.
zeros
(
block_indices
.
shape
)
for
i
in
range
(
block_indices
.
shape
[
0
]):
for
j
in
range
(
block_indices
.
shape
[
1
]):
fresh_indices
[
i
,
j
]
=
self
.
id_map
[
block_indices
[
i
,
j
]]
block_indices
=
fresh_indices
return
distances
,
block_indices
# functions below are for ALSH, which currently isn't being used
# functions below are for ALSH, which currently isn't being used
...
...
megatron/utils.py
View file @
8573ab35
...
@@ -34,7 +34,12 @@ def reduce_losses(losses):
...
@@ -34,7 +34,12 @@ def reduce_losses(losses):
reduced_losses
=
torch
.
cat
(
reduced_losses
=
torch
.
cat
(
[
loss
.
clone
().
detach
().
view
(
1
)
for
loss
in
losses
])
[
loss
.
clone
().
detach
().
view
(
1
)
for
loss
in
losses
])
torch
.
distributed
.
all_reduce
(
reduced_losses
,
group
=
get_data_parallel_group
())
torch
.
distributed
.
all_reduce
(
reduced_losses
,
group
=
get_data_parallel_group
())
reduced_losses
=
reduced_losses
/
torch
.
distributed
.
get_world_size
()
args
=
get_args
()
if
args
.
max_training_rank
is
not
None
:
num_trainers
=
args
.
max_training_rank
else
:
num_trainers
=
torch
.
distributed
.
get_world_size
()
reduced_losses
=
reduced_losses
/
num_trainers
return
reduced_losses
return
reduced_losses
...
...
pretrain_realm.py
View file @
8573ab35
...
@@ -26,7 +26,8 @@ from megatron import print_rank_0
...
@@ -26,7 +26,8 @@ from megatron import print_rank_0
from
megatron.data.dataset_utils
import
build_train_valid_test_datasets
from
megatron.data.dataset_utils
import
build_train_valid_test_datasets
from
megatron.model
import
REALMBertModel
,
REALMRetriever
from
megatron.model
import
REALMBertModel
,
REALMRetriever
from
megatron.training
import
pretrain
from
megatron.training
import
pretrain
from
megatron.utils
import
reduce_losses
from
megatron.utils
import
reduce_losses
,
report_memory
from
megatron
import
mpu
from
indexer
import
initialize_and_run_async_megatron
from
indexer
import
initialize_and_run_async_megatron
num_batches
=
0
num_batches
=
0
...
@@ -37,11 +38,14 @@ def model_provider():
...
@@ -37,11 +38,14 @@ def model_provider():
args
=
get_args
()
args
=
get_args
()
print_rank_0
(
'building REALM models ...'
)
print_rank_0
(
'building REALM models ...'
)
ict_model
=
load_ict_checkpoint
()
try
:
ict_model
=
load_ict_checkpoint
(
from_realm_chkpt
=
True
)
except
:
ict_model
=
load_ict_checkpoint
(
from_realm_chkpt
=
False
)
ict_dataset
=
get_ict_dataset
(
use_titles
=
False
)
ict_dataset
=
get_ict_dataset
(
use_titles
=
False
)
all_block_data
=
BlockData
.
load_from_file
(
args
.
block_data_path
)
all_block_data
=
BlockData
.
load_from_file
(
args
.
block_data_path
)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
hashed_index
=
FaissMIPSIndex
(
index_type
=
'flat_ip'
,
embed_size
=
128
)
hashed_index
=
FaissMIPSIndex
(
index_type
=
'flat_ip'
,
embed_size
=
128
,
use_gpu
=
args
.
faiss_use_gpu
)
hashed_index
.
add_block_embed_data
(
all_block_data
)
hashed_index
.
add_block_embed_data
(
all_block_data
)
# top_k + 1 because we may need to exclude trivial candidate
# top_k + 1 because we may need to exclude trivial candidate
...
@@ -61,6 +65,9 @@ def get_batch(data_iterator):
...
@@ -61,6 +65,9 @@ def get_batch(data_iterator):
data
=
None
data
=
None
else
:
else
:
data
=
next
(
data_iterator
)
data
=
next
(
data_iterator
)
data_b
=
mpu
.
broadcast_data
(
keys
,
data
,
datatype
)
data_b
=
mpu
.
broadcast_data
(
keys
,
data
,
datatype
)
# Unpack.
# Unpack.
...
@@ -90,9 +97,11 @@ def forward_step(data_iterator, model):
...
@@ -90,9 +97,11 @@ def forward_step(data_iterator, model):
# Forward model.
# Forward model.
lm_logits
,
block_probs
=
model
(
tokens
,
pad_mask
,
query_block_indices
)
lm_logits
,
block_probs
=
model
(
tokens
,
pad_mask
,
query_block_indices
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
retrieval_utility
=
get_retrieval_utility
(
lm_logits
,
block_probs
,
labels
,
loss_mask
)
max_retrieval_utility
,
top_retrieval_utility
,
avg_retrieval_utility
=
mpu
.
checkpoint
(
get_retrieval_utility
,
lm_logits
,
block_probs
,
labels
,
loss_mask
)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
null_block_probs
=
torch
.
mean
(
block_probs
[:,
block_probs
.
shape
[
1
]
-
1
])
block_probs
=
block_probs
.
unsqueeze
(
2
).
unsqueeze
(
3
).
expand_as
(
lm_logits
)
block_probs
=
block_probs
.
unsqueeze
(
2
).
unsqueeze
(
3
).
expand_as
(
lm_logits
)
lm_logits
=
torch
.
sum
(
lm_logits
*
block_probs
,
dim
=
1
)[:,
:
labels
.
shape
[
1
]]
lm_logits
=
torch
.
sum
(
lm_logits
*
block_probs
,
dim
=
1
)[:,
:
labels
.
shape
[
1
]]
...
@@ -101,9 +110,13 @@ def forward_step(data_iterator, model):
...
@@ -101,9 +110,13 @@ def forward_step(data_iterator, model):
lm_loss
=
torch
.
sum
(
lm_loss
=
torch
.
sum
(
lm_loss_
.
view
(
-
1
)
*
loss_mask
.
reshape
(
-
1
))
/
loss_mask
.
sum
()
lm_loss_
.
view
(
-
1
)
*
loss_mask
.
reshape
(
-
1
))
/
loss_mask
.
sum
()
reduced_loss
=
reduce_losses
([
lm_loss
,
retrieval_utility
])
reduced_loss
=
reduce_losses
([
lm_loss
,
max_
retrieval_utility
,
top_retrieval_utility
,
avg_retrieval_utility
,
null_block_probs
])
# torch.cuda.synchronize()
# torch.cuda.synchronize()
return
lm_loss
,
{
'lm_loss'
:
reduced_loss
[
0
],
'retrieval_utility'
:
reduced_loss
[
1
]}
return
lm_loss
,
{
'lm_loss'
:
reduced_loss
[
0
],
'max_ru'
:
reduced_loss
[
1
],
'top_ru'
:
reduced_loss
[
2
],
'avg_ru'
:
reduced_loss
[
3
],
'null_prob'
:
reduced_loss
[
4
]}
def
get_retrieval_utility
(
lm_logits
,
block_probs
,
labels
,
loss_mask
):
def
get_retrieval_utility
(
lm_logits
,
block_probs
,
labels
,
loss_mask
):
...
@@ -129,9 +142,10 @@ def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
...
@@ -129,9 +142,10 @@ def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
retrieved_block_loss_
.
view
(
-
1
)
*
loss_mask
.
reshape
(
-
1
))
/
loss_mask
.
sum
()
retrieved_block_loss_
.
view
(
-
1
)
*
loss_mask
.
reshape
(
-
1
))
/
loss_mask
.
sum
()
retrieved_block_losses
.
append
(
retrieved_block_loss
)
retrieved_block_losses
.
append
(
retrieved_block_loss
)
avg_retrieved_block_loss
=
torch
.
sum
(
torch
.
cuda
.
FloatTensor
(
retrieved_block_losses
))
/
(
lm_logits
.
shape
[
1
]
-
1
)
avg_retrieved_block_loss
=
torch
.
sum
(
torch
.
cuda
.
FloatTensor
(
retrieved_block_losses
))
/
(
lm_logits
.
shape
[
1
]
-
1
)
max_retrieval_utility
=
null_block_loss
-
min
(
retrieved_block_losses
)
retrieval_utility
=
null_block_loss
-
avg_retrieved_block_loss
top_retrieval_utility
=
null_block_loss
-
retrieved_block_losses
[
0
]
return
retrieval_utility
avg_retrieval_utility
=
null_block_loss
-
avg_retrieved_block_loss
return
max_retrieval_utility
,
top_retrieval_utility
,
avg_retrieval_utility
def
qa_forward_step
(
data_iterator
,
model
):
def
qa_forward_step
(
data_iterator
,
model
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment