OpenDAS / Megatron-LM

Commit a9e19f8e, authored Mar 27, 2020 by Mohammad
added initialize megatron
parent 83aa9219
Showing 2 changed files with 125 additions and 0 deletions:

megatron/arguments.py     +16   -0
megatron/initialize.py    +109  -0
megatron/arguments.py

@@ -24,6 +24,22 @@ _GLOBAL_ARGS = None

def _print_args():
    """Print arguments."""
    args = get_args()
    writer = get_tensorboard_writer()
    print_rank_0('arguments:')
    str_list = []
    for arg in vars(args):
        dots = '.' * (29 - len(arg))
        str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        if writer:
            writer.add_text(arg, str(getattr(args, arg)))
    for arg in sorted(str_list, key=lambda x: x.lower()):
        print_rank_0(arg)


def parse_args(extra_args_provider=None):
    global _GLOBAL_ARGS
    ...
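The padding logic above lines every argument value up in a single column. Below is a minimal standalone sketch of that same formatting, not part of the commit: argparse.Namespace stands in for the namespace returned by get_args(), and the attribute names are illustrative only.

from argparse import Namespace

# Stand-in for Megatron's parsed args; these attributes are made up for the sketch.
args = Namespace(seed=1234, model_parallel_size=2, distributed_backend='nccl')

str_list = []
for arg in vars(args):
    # Pad each argument name with dots out to a 29-character column,
    # exactly as _print_args does.
    dots = '.' * (29 - len(arg))
    str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
for line in sorted(str_list, key=lambda x: x.lower()):
    print(line)

# Example output (sorted case-insensitively, names padded to a fixed column):
#   distributed_backend .......... nccl
#   model_parallel_size .......... 2
#   seed ......................... 1234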
megatron/initialize.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron initialization."""

import random
import os

import numpy as np
import torch

from megatron import mpu
from .global_vars import get_adlr_autoresume
from .global_vars import get_args
from .global_vars import set_global_variables


def initialize_megatron(extra_args_provider=None):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds."""

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider)

    # Pytorch distributed.
    _initialize_distributed()

    # Autoresume.
    _init_autoresume()

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        print('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    if torch.distributed.is_initialized():
        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
        device = torch.cuda.current_device()
        local_rank = args.rank % torch.cuda.device_count()
        assert local_rank == device, \
            'expected local-rank to be the same as rank % device-count.'
    else:
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        device = args.rank % torch.cuda.device_count()
        if args.local_rank is not None:
            assert args.local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'
        else:
            args.local_rank = device
        torch.cuda.set_device(device)
        # Call the init process
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed):
    """Set random seed for reproducibility."""
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
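_initialize_distributed builds a tcp:// init method from the MASTER_ADDR and MASTER_PORT environment variables and passes it to torch.distributed.init_process_group. Below is a minimal standalone sketch of that pattern, not part of the commit: it starts a single-process 'gloo' group on CPU so no GPU or launcher is required, whereas Megatron would pass args.distributed_backend (typically 'nccl') and the real rank and world size. The helper name init_single_process_group is made up for this sketch.

import os
import torch

def init_single_process_group():
    # Same environment variables (and defaults) the Megatron code reads.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://' + master_ip + ':' + master_port
    # Single-process CPU group for illustration only.
    torch.distributed.init_process_group(
        backend='gloo',
        world_size=1,
        rank=0,
        init_method=init_method)
    print('> rank {} of {} initialized via {}'.format(
        torch.distributed.get_rank(),
        torch.distributed.get_world_size(),
        init_method), flush=True)

if __name__ == '__main__':
    init_single_process_group()
    torch.distributed.destroy_process_group()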