OpenDAS / Megatron-LM · Commits

Commit 05ea0cca authored May 21, 2020 by Neel Kant

Change sync variable to gloo backend

parent a670b6c9
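In brief (a reading of the diff below, not text from the commit itself): the INDEX_READY handshake between the REALM training ranks and the asynchronous index-builder ranks previously rode on the default process group, which with the NCCL backend requires CUDA tensors. This commit creates a dedicated gloo-backend group spanning all ranks, routes the INDEX_READY broadcasts through it, and keeps the flag as a plain CPU tensor. A minimal, hedged sketch of the pattern (function and variable names here are illustrative, not the repo's):

    # Illustrative sketch only; not code from Megatron-LM.
    import datetime
    import torch
    import torch.distributed as dist

    def init_comm(world_size):
        # Default group stays on NCCL for the heavy GPU collectives; assumes the usual
        # RANK / WORLD_SIZE / MASTER_ADDR environment set up by the launcher.
        dist.init_process_group(backend="nccl")
        # Extra CPU-side group for the sync variable; gloo collectives accept CPU
        # tensors, so the flag never has to be moved to the GPU.
        gloo_group = dist.new_group(list(range(world_size)), backend="gloo",
                                    timeout=datetime.timedelta(hours=2))
        index_ready = torch.zeros(1)  # CPU tensor, no .cuda()
        return gloo_group, index_ready

    def post_flag_broadcast(index_ready, src_rank, gloo_group):
        # Non-blocking broadcast: every rank gets back a work handle it can poll
        # with handle.is_completed() instead of waiting on a GPU stream.
        return dist.broadcast(index_ready, src_rank, group=gloo_group, async_op=True)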
Showing 4 changed files with 33 additions and 12 deletions (+33 −12):

    indexer.py                  +7  −5
    megatron/initialize.py      +1  −0
    megatron/mpu/initialize.py  +13 −1
    megatron/training.py        +12 −6
indexer.py

@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.global_vars import set_global_variables
-from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
 from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
 from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
 from megatron.training import get_model

@@ -176,10 +176,10 @@ class AsyncIndexBuilder(object):
         INDEX_READY = 1 - INDEX_READY
         print("Switched INDEX_READY", flush=True)
         torch.cuda.synchronize()
-        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
+        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
         torch.distributed.barrier(get_data_parallel_group())
-        recv_handle = dist.broadcast(INDEX_READY, 0)
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())


 class BasicIndexBuilder(object):

@@ -287,12 +287,14 @@ def get_ict_dataset(use_titles=True):
     return dataset


-def get_one_epoch_dataloader(dataset):
+def get_one_epoch_dataloader(dataset, batch_size=None):
     args = get_args()

     world_size = mpu.get_data_parallel_world_size()
     rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
     num_workers = args.num_workers

     sampler = torch.utils.data.SequentialSampler(dataset)
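Aside from the gloo change, get_one_epoch_dataloader gains an optional batch_size argument, so the index builder can request a per-GPU batch size different from args.batch_size. A small self-contained sketch of that default-argument behavior (the helper name and numbers are made up for illustration):

    # Hypothetical helper mirroring the new fallback logic; not repo code.
    def one_epoch_iters(dataset_len, args_batch_size, world_size, batch_size=None):
        if batch_size is None:
            batch_size = args_batch_size          # fall back to the training batch size
        global_batch_size = batch_size * world_size
        return dataset_len // global_batch_size   # iterations needed for one epoch

    print(one_epoch_iters(10240, args_batch_size=8, world_size=4))                  # 320
    print(one_epoch_iters(10240, args_batch_size=8, world_size=4, batch_size=32))   # 80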
megatron/initialize.py

@@ -15,6 +15,7 @@
 """Megatron initialization."""

+import datetime
 import random
 import os
megatron/mpu/initialize.py

@@ -16,6 +16,7 @@
 """Model and data parallel groups."""

+import datetime
 import torch

 from .utils import ensure_divisibility

@@ -26,6 +27,7 @@ _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+_GLOO_COMM_GROUP = None
 _TRAIN_GROUP = None
 _INDEX_GROUP = None
 _INDEX_READY = None

@@ -177,12 +179,22 @@ def destroy_model_parallel():

 def init_realm_groups(max_training_rank, world_size):
+    global _GLOO_COMM_GROUP
+    _GLOO_COMM_GROUP = torch.distributed.new_group(list(range(world_size)), backend="gloo", timeout=datetime.timedelta(0, 7200))
     global _TRAIN_GROUP
     _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
     global _INDEX_GROUP
     _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
     global _INDEX_READY
-    _INDEX_READY = torch.zeros(1).cuda()
+    _INDEX_READY = torch.zeros(1)
+
+
+def get_gloo_comm_group():
+    global _GLOO_COMM_GROUP
+    assert _GLOO_COMM_GROUP is not None
+    return _GLOO_COMM_GROUP


 def get_train_group():
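Two details worth noting in this hunk: the new gloo group is created with a positional timeout of datetime.timedelta(0, 7200), i.e. 0 days and 7200 seconds (two hours), presumably generous enough to cover an index rebuild; and _INDEX_READY drops its .cuda() call because gloo collectives operate on CPU tensors. A quick check of the timeout value (an illustration, not repo code):

    # timedelta's first two positional arguments are days and seconds.
    import datetime
    assert datetime.timedelta(0, 7200) == datetime.timedelta(hours=2)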
megatron/training.py

@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory

@@ -374,14 +374,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     global INDEX_READY
-    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+    # start off by posting a receive call which will be answered.
+    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
     while iteration < args.train_iters:
-        if args.max_training_rank is not None and INDEX_READY == 1:
+        # this only applies for realm right here
+        if args.max_training_rank is not None and recv_handle.is_completed():
+            # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
             if hasattr(true_model, 'module'):
                 true_model = true_model.module
             if hasattr(true_model, 'module'):
                 true_model = true_model.module
             print(">>>>>>> starting to reload index", flush=True)
             true_model.retriever.reload_index()
             save_checkpoint(iteration, model, optimizer, lr_scheduler)

@@ -390,10 +396,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             INDEX_READY = 1 - INDEX_READY
             print(">>> Switched index ready", flush=True)
             torch.cuda.synchronize()
-            send_handle = torch.distributed.broadcast(INDEX_READY, 0)
+            send_handle = torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
             torch.distributed.barrier(get_data_parallel_group())
-            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
-        else:
+            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+        elif iteration < 100:
             print("moving right along", flush=True)

         loss_dict, skipped_iter = train_step(forward_step_func,
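The training-loop side of the handshake now polls the async work handle instead of reading the flag directly: a receive-style broadcast is posted once before the loop, recv_handle.is_completed() is checked each iteration (a cheap host-side call on the gloo group), and after the index reload the flag is flipped, broadcast back from rank 0, and a fresh receive is posted. A hedged, stripped-down sketch of that control flow (training_rank_loop and on_index_ready are hypothetical names; group setup and train_step are elided):

    # Illustrative sketch of the polling pattern in train(); not the repo's code.
    import torch
    import torch.distributed as dist

    def training_rank_loop(index_ready, gloo_group, index_src_rank, train_iters, on_index_ready):
        # Post the initial "receive" without blocking the training loop.
        recv_handle = dist.broadcast(index_ready, index_src_rank, group=gloo_group, async_op=True)
        iteration = 0
        while iteration < train_iters:
            if recv_handle.is_completed():
                on_index_ready()                            # e.g. reload the retriever index
                index_ready.fill_(1 - index_ready.item())   # flip the flag, as in the diff
                # Acknowledge from rank 0, then arm the next receive.
                dist.broadcast(index_ready, 0, group=gloo_group)
                recv_handle = dist.broadcast(index_ready, index_src_rank,
                                             group=gloo_group, async_op=True)
            # ... train_step(...) and logging would run here ...
            iteration += 1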