OpenDAS / Megatron-LM
Commit 5942af97, authored Sep 23, 2022 by Jared Casper
Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
Parent: c2ea914f
Changes: 63 files in total. Showing the first 20 changed files, with 148 additions and 152 deletions.
megatron/checkpointing.py                    +19 -19
megatron/core/__init__.py                    +12 -13
megatron/core/parallel_state.py               +0  -4
megatron/core/tensor_parallel/__init__.py     +7  -2
megatron/core/tensor_parallel/random.py       +4 -32
megatron/core/tensor_parallel/utils.py       +66  -9
megatron/core/utils.py                        +0 -29
megatron/data/biencoder_dataset_utils.py      +3  -2
megatron/data/blendable_dataset.py            +0  -2
megatron/data/data_samplers.py                +1  -1
megatron/data/dataset_utils.py                +1  -1
megatron/data/gpt_dataset.py                  +2  -1
megatron/data/orqa_wiki_dataset.py            +3  -2
megatron/data/realm_dataset_utils.py          +4  -3
megatron/data/realm_index.py                  +7  -7
megatron/indexer.py                           +1  -1
megatron/initialize.py                       +12 -17
megatron/model/bert_model.py                  +4  -4
megatron/model/biencoder_model.py             +2  -2
megatron/model/classification.py              +0  -1
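For orientation, a minimal sketch (not part of the commit) of the import style this diff converges on: call sites keep the short mpu. prefix, but the name now comes from megatron.core rather than the legacy megatron/mpu package. It assumes a job in which initialize_megatron() has already set up torch.distributed and the model-parallel groups.

    # Hypothetical call site, assuming model parallelism is already initialized.
    from megatron.core import mpu, tensor_parallel

    if mpu.get_data_parallel_rank() == 0:
        # Parallel-state queries go through the mpu alias; tensor-parallel helpers
        # (broadcast_data, the RNG tracker, etc.) go through tensor_parallel.
        print('data-parallel rank 0')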
megatron/checkpointing.py

@@ -9,8 +9,8 @@ import numpy as np
 import torch

-from megatron import (core, update_num_microbatches)
+from megatron import update_num_microbatches
+from megatron.core import mpu, tensor_parallel
 from .global_vars import get_args
 from .utils import (unwrap_model, print_rank_0)
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = core.get_tensor_model_parallel_rank()
+        tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = core.get_pipeline_model_parallel_rank()
+        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % core.get_data_parallel_rank(),
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
                  'np_rng_state': np.random.get_state(),
                  'torch_rng_state': torch.get_rng_state(),
                  'cuda_rng_state': torch.cuda.get_rng_state(),
-                 'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
+                 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            core.get_data_parallel_world_size() > 1 and \
+            mpu.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
-            [None for i in range(core.get_data_parallel_world_size())]
+            [None for i in range(mpu.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             rng_state_list,
             rng_state,
-            group=core.get_data_parallel_group())
+            group=mpu.get_data_parallel_group())
     else:
         rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Collect args, model, RNG.
     model_state_dict = {}
     if not torch.distributed.is_initialized() \
-       or core.get_data_parallel_rank() == 0:
+       or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
         model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
             model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
-                core.set_virtual_pipeline_model_parallel_rank(i)
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
                 model_state_dict['model%d' % i] = \
                     model[i].state_dict_for_save_checkpoint()
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     optim_state_dict = {}
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
-            or core.get_data_parallel_rank() == 0
+            or mpu.get_data_parallel_rank() == 0
             or args.use_distributed_optimizer):

         # Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
-            core.set_virtual_pipeline_model_parallel_rank(i)
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
             model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             # access rng_state for data parallel rank
             if args.data_parallel_random_init:
-                rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
+                rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
             else:
                 rng_state = model_state_dict['rng_state'][0]
             random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
                 random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not model_state_dict['rng_tracker_states']:
                     raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                     model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
                                           args.use_distributed_optimizer,
                                           release=False)

-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
megatron/core/__init__.py

-from .parallel_state import (
-    initialize_model_parallel,
-    get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_pipeline_model_parallel_world_size,
-    get_pipeline_model_parallel_rank,
-    get_virtual_pipeline_model_parallel_rank,
-    set_virtual_pipeline_model_parallel_rank,
-    get_data_parallel_world_size,
-    get_data_parallel_rank,
-    get_global_memory_buffer,
-    get_num_layers,
-)
-from megatron.core import tensor_parallel
+import megatron.core.parallel_state
+import megatron.core.tensor_parallel
+import megatron.core.utils
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+]
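Because mpu is bound to the parallel_state module object itself, the two names stay interchangeable. A quick sanity check, a sketch only and not part of the commit:

    from megatron.core import mpu, parallel_state

    # The alias is the module itself, so both names resolve to the same objects.
    assert mpu is parallel_state
    assert mpu.get_data_parallel_rank is parallel_state.get_data_parallel_rank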
megatron/core/parallel_state.py

@@ -47,10 +47,6 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None

-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None
-
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
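Code that relied on the removed is_unitialized() helper switches to the existing initialization check, as the megatron/data/realm_index.py hunks below do. A minimal sketch, assuming only that megatron.core is importable:

    from megatron.core import mpu

    # Guard work that must not assume model-parallel groups exist yet.
    if not mpu.model_parallel_is_initialized():
        print('model parallel groups are not initialized yet')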
megatron/core/tensor_parallel/__init__.py

@@ -5,6 +5,7 @@ from .layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
+    set_tensor_model_parallel_attributes,
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
@@ -23,10 +24,14 @@ from .mappings import (
 from .random import (
     checkpoint,
     get_cuda_rng_tracker,
-    model_parallel_cuda_manual_seed
+    model_parallel_cuda_manual_seed,
 )

-from .utils import split_tensor_along_last_dim
+from .utils import (
+    split_tensor_along_last_dim,
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)

 __all__ = [
     # cross_entropy.py
megatron/core/tensor_parallel/random.py

@@ -17,6 +17,10 @@ from megatron.core.parallel_state import (
     get_tensor_model_parallel_world_size,
 )

+from .utils import (
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)
+
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -55,38 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """Break a tensor into equal 1D chunks."""
-    partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-        data.copy_(tensor.view(-1)[start_index:end_index])
-    else:
-        data = tensor.view(-1)[start_index:end_index]
-    return data
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    # TODO: This API is experimental in pytorch (as of Feb 2022) and
-    # this might break in future pytorch releases. We chose this API
-    # as opposed to torch.distributed.all_gather for efficiency reasons.
-    # This API calls directly NCCL all-gather versus the former does
-    # internal copies and can potentially cause slow down.
-    torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
-    return gathered
-
-
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
megatron/core/tensor_parallel/utils.py

@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
     num_partitions: int,
     contiguous_split_chunks: bool = False,
 ) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
-    Arguments:
-        tensor: input tensor.
-        num_partitions: number of partitions to split the tensor
-        contiguous_split_chunks: If True, make each chunk contiguous
-                                 in memory.
+    """ Split a tensor along its last dimension.
+
+        Arguments:
+            tensor: input tensor.
+            num_partitions: number of partitions to split the tensor
+            contiguous_split_chunks: If True, make each chunk contiguous
+                                     in memory.
+
+        Returns:
+            A list of Tensors
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1
@@ -28,11 +32,64 @@ def split_tensor_along_last_dim(
     return tensor_list


+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
+
+        Returns a Tensor or View with this rank's portion of the data.
+
+        Arguments:
+            tensor: The tensor to split
+
+        Keyword Arguments:
+            new_buffer (bool): If True, returns a new Tensor.
+                               If False, returns a view into the existing Tensor.
+                               Default is False
+    """
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
+        model parallel ranks.
+
+        Returns a new Tensor with the gathered data.
+
+        Arguments:
+            tensor: A Tensor or view of this rank's portion of the data.
+    """
+    numel_gathered = torch.numel(tensor) * \
+        get_tensor_model_parallel_world_size()
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    # TODO: This API is experimental in pytorch (as of Feb 2022) and
+    # this might break in future pytorch releases. We chose this API
+    # as opposed to torch.distributed.all_gather for efficiency reasons.
+    # This API calls directly NCCL all-gather versus the former does
+    # internal copies and can potentially cause slow down.
+    torch.distributed._all_gather_base(gathered, tensor,
+                                       group=get_tensor_model_parallel_group())
+    return gathered
+
+
 class VocabUtility:
-    """Split the vocabulary into `world_size` chunks and return the
-       first and last index of the vocabulary belonging to the `rank`
-       partition: Note that indices in [fist, last)"""
+    """ Split the vocabulary into `world_size` chunks and return the first
+        and last index of the vocabulary belonging to the `rank`
+        partition: Note that indices in [fist, last)
+    """

     @staticmethod
     def vocab_range_from_per_partition_vocab_size(
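As a usage sketch of the relocated helpers (illustrative, not part of the diff): it assumes model parallelism is initialized, a CUDA device is available, and the tensor's element count divides evenly by the tensor-parallel world size, as the implementation requires.

    import torch
    from megatron.core import tensor_parallel

    # Each rank keeps its contiguous 1D slice, then the slices are all-gathered
    # back into a flat copy of the original tensor.
    full = torch.arange(16, dtype=torch.float32, device='cuda')
    local = tensor_parallel.split_tensor_into_1d_equal_chunks(full, new_buffer=True)
    restored = tensor_parallel.gather_split_1d_tensor(local)
    assert torch.equal(restored, full.view(-1))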
megatron/core/utils.py

@@ -21,35 +21,6 @@ def divide(numerator, denominator):
     return numerator // denominator


-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = (
-        torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
-    )
-    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = parallel_state.get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(
-        numel_gathered,
-        dtype=tensor.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
-    )
-    torch.distributed._all_gather_base(
-        gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
-    )
-    return gathered
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
        Caller should ensure that buffers of the same name
megatron/data/biencoder_dataset_utils.py

@@ -4,7 +4,8 @@ import time
 import numpy as np
 import torch

-from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron import get_args, get_tokenizer, print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, \
     pad_and_convert_to_numpy
 from megatron.data.data_samplers import MegatronPretrainingSampler
@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
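The same one-line substitution, mpu.broadcast_data becoming tensor_parallel.broadcast_data, repeats in the dataset files below. A hedged sketch of the call pattern (build_batch is a hypothetical stand-in for whatever produces the batch on tensor-parallel rank 0):

    import torch
    from megatron.core import mpu, tensor_parallel

    keys = ['query_tokens', 'query_mask']
    datatype = torch.int64
    # Only tensor-model-parallel rank 0 needs real data; other ranks pass None
    # and receive the broadcast members of `keys`.
    data = build_batch() if mpu.get_tensor_model_parallel_rank() == 0 else None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)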
megatron/data/blendable_dataset.py

@@ -8,8 +8,6 @@ import numpy as np
 import torch

 from megatron import print_rank_0
-from megatron import mpu
-

 class BlendableDataset(torch.utils.data.Dataset):
megatron/data/data_samplers.py

@@ -8,7 +8,7 @@ import torch
 import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu


 def build_pretraining_data_loader(dataset, consumed_samples):
megatron/data/dataset_utils.py

@@ -28,9 +28,9 @@ import torch
 from megatron import (
     get_args,
-    mpu,
     print_rank_0
 )
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
megatron/data/gpt_dataset.py

@@ -8,7 +8,8 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
megatron/data/orqa_wiki_dataset.py

@@ -9,7 +9,8 @@ import random
 import torch
 from torch.utils.data import Dataset

-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
+from megatron.core import tensor_parallel
 from megatron.data.biencoder_dataset_utils import make_attention_mask

 def get_open_retrieval_wiki_dataset():
@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
     # Broadcast data.
     data = None if data_iterator is None else next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     row_id = data_b['row_id'].long()
megatron/data/realm_dataset_utils.py

@@ -4,9 +4,10 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0


 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
megatron/data/realm_index.py

@@ -7,7 +7,7 @@ import numpy as np
 import torch

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu

 def detach(tensor):
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
     def load_from_file(self):
         """Populate members from instance saved to file"""

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
         state_dict = pickle.load(open(self.embedding_path, 'rb'))
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
         except ImportError:
             raise Exception("Error: Please install faiss to use FaissMIPSIndex")

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)

         cpu_index = faiss.IndexFlatIP(self.embed_size)
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
             config.useFloat16 = True
             gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
             self.mips_index = faiss.IndexIDMap(gpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
             self.mips_index = faiss.IndexIDMap(cpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)

         # if we were constructed with a BlockData, then automatically load it
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
         self.mips_index.add_with_ids(embeds_arr, indices_arr)

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
megatron/indexer.py

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist

 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
megatron/initialize.py

@@ -14,13 +14,10 @@ from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron.arguments import (parse_args, validate_args)
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
-from megatron.mpu import (set_tensor_model_parallel_rank,
-                          set_tensor_model_parallel_world_size)
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu
@@ -65,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     args = get_args()
     if args.lazy_mpu_init:
         # TODO is this still a necessary option?
         args.use_cpu_initialization = True
         # delayed initialization of DDP-related stuff
-        # We only set basic DDP globals
-        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+        # We only set basic DDP globals
+        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
         # and return function for external DDP manager
         # to call when it has DDP initialized
-        set_tensor_model_parallel_rank(args.rank)
+        mpu.set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
         # Megatron's MPU is the master. Complete initialization right away.
@@ -147,7 +145,7 @@ def _compile_dependencies():
 def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
+    """Initialize torch.distributed and core model parallel."""
     args = get_args()

     device_count = torch.cuda.device_count()
@@ -185,17 +183,14 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
-            core.initialize_model_parallel(args.tensor_model_parallel_size,
-                                           args.pipeline_model_parallel_size,
-                                           args.virtual_pipeline_model_parallel_size,
-                                           args.pipeline_model_parallel_split_rank)
-            print(f'> initialized tensor model parallel with size '
-                  f'{core.get_tensor_model_parallel_world_size()}')
-            print(f'> initialized pipeline model parallel with size '
-                  f'{core.get_pipeline_model_parallel_world_size()}')
+            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size,
+                                          args.pipeline_model_parallel_split_rank)
+            if args.rank == 0:
+                print(f'> initialized tensor model parallel with size '
+                      f'{mpu.get_tensor_model_parallel_world_size()}')
+                print(f'> initialized pipeline model parallel with size '
+                      f'{mpu.get_pipeline_model_parallel_world_size()}')


 def _init_autoresume():
@@ -219,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         np.random.seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.device_count() > 0:
-            core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
+            tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
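Put together, the post-commit initialization path boils down to roughly this sequence. This is a simplified sketch with hypothetical sizes and seed, not the actual initialize_megatron() code, and it assumes torch.distributed.init_process_group() has already run.

    from megatron.core import mpu, tensor_parallel

    # Build tensor/pipeline/data parallel groups on top of torch.distributed.
    mpu.initialize_model_parallel(tensor_model_parallel_size=2,
                                  pipeline_model_parallel_size=1)
    # Seed the per-rank CUDA RNG tracker used by the tensor-parallel layers.
    tensor_parallel.model_parallel_cuda_manual_seed(1234)

    print(mpu.get_tensor_model_parallel_world_size(),
          mpu.get_pipeline_model_parallel_world_size())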
megatron/model/bert_model.py

@@ -5,7 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
+        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -110,9 +110,9 @@ def post_language_model_processing(lm_output, pooled_output,
         # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                        lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                   lm_labels)
         # [s, b] => [b s]
         lm_loss = lm_loss.transpose(0, 1).contiguous()
megatron/model/biencoder_model.py

@@ -2,11 +2,11 @@ import os
 import torch
 import sys

-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_0, get_tokenizer
+from megatron.core import mpu
 from megatron.checkpointing import fix_query_key_value_ordering
 from megatron.checkpointing import get_checkpoint_tracker_filename
 from megatron.checkpointing import get_checkpoint_name
-from megatron import mpu, get_tokenizer
 from megatron.model.bert_model import bert_position_ids
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
megatron/model/classification.py

@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model