Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
eaa5d877
Commit
eaa5d877
authored
Jul 25, 2020
by
Neel Kant
Browse files
Add additional assertion on Indexer to test correctness, and limit verbosity in other classes
parent
c2a32e12
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
11 deletions
+25
-11
megatron/data/realm_index.py
megatron/data/realm_index.py
+14
-7
megatron/indexer.py
megatron/indexer.py
+5
-4
megatron/mpu/__init__.py
megatron/mpu/__init__.py
+1
-0
megatron/mpu/initialize.py
megatron/mpu/initialize.py
+5
-0
No files found.
megatron/data/realm_index.py
View file @
eaa5d877
...
@@ -7,6 +7,7 @@ import numpy as np
...
@@ -7,6 +7,7 @@ import numpy as np
import
torch
import
torch
from
megatron
import
get_args
from
megatron
import
get_args
from
megatron
import
mpu
def
detach
(
tensor
):
def
detach
(
tensor
):
...
@@ -47,9 +48,11 @@ class BlockData(object):
...
@@ -47,9 +48,11 @@ class BlockData(object):
def
load_from_file
(
self
):
def
load_from_file
(
self
):
"""Populate members from instance saved to file"""
"""Populate members from instance saved to file"""
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
state_dict
=
pickle
.
load
(
open
(
self
.
block_data_path
,
'rb'
))
state_dict
=
pickle
.
load
(
open
(
self
.
block_data_path
,
'rb'
))
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
self
.
embed_data
=
state_dict
[
'embed_data'
]
self
.
embed_data
=
state_dict
[
'embed_data'
]
self
.
meta_data
=
state_dict
[
'meta_data'
]
self
.
meta_data
=
state_dict
[
'meta_data'
]
...
@@ -127,7 +130,8 @@ class FaissMIPSIndex(object):
...
@@ -127,7 +130,8 @@ class FaissMIPSIndex(object):
except
ImportError
:
except
ImportError
:
raise
Exception
(
"Error: Please install faiss to use FaissMIPSIndex"
)
raise
Exception
(
"Error: Please install faiss to use FaissMIPSIndex"
)
print
(
"
\n
> Building index"
,
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
"
\n
> Building index"
,
flush
=
True
)
self
.
block_mips_index
=
faiss
.
index_factory
(
self
.
embed_size
,
'Flat'
,
faiss
.
METRIC_INNER_PRODUCT
)
self
.
block_mips_index
=
faiss
.
index_factory
(
self
.
embed_size
,
'Flat'
,
faiss
.
METRIC_INNER_PRODUCT
)
if
self
.
use_gpu
:
if
self
.
use_gpu
:
...
@@ -138,11 +142,13 @@ class FaissMIPSIndex(object):
...
@@ -138,11 +142,13 @@ class FaissMIPSIndex(object):
config
.
useFloat16
=
True
config
.
useFloat16
=
True
self
.
block_mips_index
=
faiss
.
GpuIndexFlat
(
res
,
self
.
block_mips_index
,
config
)
self
.
block_mips_index
=
faiss
.
GpuIndexFlat
(
res
,
self
.
block_mips_index
,
config
)
print
(
">> Initialized index on GPU {}"
.
format
(
self
.
block_mips_index
.
getDevice
()),
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Initialized index on GPU {}"
.
format
(
self
.
block_mips_index
.
getDevice
()),
flush
=
True
)
else
:
else
:
# CPU index supports IDs so wrap with IDMap
# CPU index supports IDs so wrap with IDMap
self
.
block_mips_index
=
faiss
.
IndexIDMap
(
self
.
block_mips_index
)
self
.
block_mips_index
=
faiss
.
IndexIDMap
(
self
.
block_mips_index
)
print
(
">> Initialized index on CPU"
,
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Initialized index on CPU"
,
flush
=
True
)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if
self
.
block_data
is
not
None
:
if
self
.
block_data
is
not
None
:
...
@@ -156,7 +162,7 @@ class FaissMIPSIndex(object):
...
@@ -156,7 +162,7 @@ class FaissMIPSIndex(object):
if
self
.
block_data
is
not
None
:
if
self
.
block_data
is
not
None
:
block_data_path
=
self
.
block_data
.
block_data_path
block_data_path
=
self
.
block_data
.
block_data_path
del
self
.
block_data
del
self
.
block_data
self
.
block_data
=
BlockData
.
load_from_file
(
block_data_path
)
self
.
block_data
=
BlockData
(
block_data_path
)
self
.
_set_block_index
()
self
.
_set_block_index
()
...
@@ -183,7 +189,8 @@ class FaissMIPSIndex(object):
...
@@ -183,7 +189,8 @@ class FaissMIPSIndex(object):
else
:
else
:
self
.
block_mips_index
.
add_with_ids
(
block_embeds_arr
,
block_indices_arr
)
self
.
block_mips_index
.
add_with_ids
(
block_embeds_arr
,
block_indices_arr
)
print
(
">>> Finished adding block data to index"
,
flush
=
True
)
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">>> Finished adding block data to index"
,
flush
=
True
)
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
def
search_mips_index
(
self
,
query_embeds
,
top_k
,
reconstruct
=
True
):
"""Get the top-k blocks by the index distance metric.
"""Get the top-k blocks by the index distance metric.
...
...
megatron/indexer.py
View file @
eaa5d877
...
@@ -37,7 +37,8 @@ class IndexBuilder(object):
...
@@ -37,7 +37,8 @@ class IndexBuilder(object):
model
=
get_model
(
lambda
:
general_ict_model_provider
(
only_block_model
=
True
))
model
=
get_model
(
lambda
:
general_ict_model_provider
(
only_block_model
=
True
))
self
.
model
=
load_ict_checkpoint
(
model
,
only_block_model
=
True
,
from_realm_chkpt
=
self
.
using_realm_chkpt
)
self
.
model
=
load_ict_checkpoint
(
model
,
only_block_model
=
True
,
from_realm_chkpt
=
self
.
using_realm_chkpt
)
self
.
model
.
eval
()
self
.
model
.
eval
()
self
.
dataloader
=
iter
(
get_one_epoch_dataloader
(
get_ict_dataset
(),
self
.
batch_size
))
self
.
dataset
=
get_ict_dataset
()
self
.
dataloader
=
iter
(
get_one_epoch_dataloader
(
self
.
dataset
,
self
.
batch_size
))
self
.
block_data
=
BlockData
(
load_from_path
=
False
)
self
.
block_data
=
BlockData
(
load_from_path
=
False
)
def
track_and_report_progress
(
self
,
batch_size
):
def
track_and_report_progress
(
self
,
batch_size
):
...
@@ -58,7 +59,7 @@ class IndexBuilder(object):
...
@@ -58,7 +59,7 @@ class IndexBuilder(object):
try
:
try
:
# batch also has query_tokens and query_pad_data
# batch also has query_tokens and query_pad_data
_
,
_
,
block_tokens
,
block_pad_mask
,
block_sample_data
=
get_ict_batch
(
self
.
dataloader
)
_
,
_
,
block_tokens
,
block_pad_mask
,
block_sample_data
=
get_ict_batch
(
self
.
dataloader
)
except
StopIteration
:
except
(
StopIteration
,
IndexError
)
:
break
break
unwrapped_model
=
self
.
model
unwrapped_model
=
self
.
model
...
@@ -85,6 +86,6 @@ class IndexBuilder(object):
...
@@ -85,6 +86,6 @@ class IndexBuilder(object):
# rank 0 process builds the final copy
# rank 0 process builds the final copy
if
self
.
is_main_builder
:
if
self
.
is_main_builder
:
self
.
block_data
.
merge_shards_and_save
()
self
.
block_data
.
merge_shards_and_save
()
# make sure that every single piece of data was embedded
assert
len
(
self
.
block_data
.
embed_data
)
==
len
(
self
.
dataset
)
self
.
block_data
.
clear
()
self
.
block_data
.
clear
()
megatron/mpu/__init__.py
View file @
eaa5d877
...
@@ -21,6 +21,7 @@ from .data import broadcast_data
...
@@ -21,6 +21,7 @@ from .data import broadcast_data
from
.grads
import
clip_grad_norm
from
.grads
import
clip_grad_norm
from
.initialize
import
is_unitialized
from
.initialize
import
destroy_model_parallel
from
.initialize
import
destroy_model_parallel
from
.initialize
import
get_data_parallel_group
from
.initialize
import
get_data_parallel_group
from
.initialize
import
get_data_parallel_rank
from
.initialize
import
get_data_parallel_rank
...
...
megatron/mpu/initialize.py
View file @
eaa5d877
...
@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
...
@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
_MPU_RANK
=
None
_MPU_RANK
=
None
def is_unitialized():
    """Return True when mpu (model-parallel) state has not been initialized.

    Useful for code segments that may run either with or without a prior call
    to the mpu initialization routine: the data-parallel group global is only
    populated during initialization, so its absence marks the uninitialized
    state.

    NOTE(review): the name "is_unitialized" is a misspelling of
    "is_uninitialized", but it is re-exported as part of the public mpu API
    (megatron/mpu/__init__.py), so the identifier is kept unchanged.
    """
    initialized = _DATA_PARALLEL_GROUP is not None
    return not initialized
def
initialize_model_parallel
(
model_parallel_size_
):
def
initialize_model_parallel
(
model_parallel_size_
):
"""
"""
Initialize model data parallel groups.
Initialize model data parallel groups.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment