wuxk1 / Megatron-LM / Commits / 942c402d

Commit 942c402d, authored Feb 18, 2022 by Jared Casper

    Making loading arguments from checkpoint cleaner and available more broadly.

Parent: 06fc51ce

Showing 9 changed files with 97 additions and 80 deletions
  megatron/__init__.py                  +3   -19
  megatron/arguments.py                 +8   -7
  megatron/checkpointing.py             +16  -11
  megatron/global_vars.py               +10  -14
  megatron/initialize.py                +13  -4
  megatron/model/module.py              +10  -10
  megatron/utils.py                     +19  -1
  tools/checkpoint_loader_megatron.py   +12  -12
  tools/checkpoint_saver_megatron.py    +6   -2
megatron/__init__.py  (+3, -19)

@@ -25,22 +25,6 @@ from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
 from .initialize import initialize_megatron
 
-def print_rank_0(message):
-    """If distributed is initialized, print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-
-def is_last_rank():
-    return torch.distributed.get_rank() == (
-        torch.distributed.get_world_size() - 1)
-
-def print_rank_last(message):
-    """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
-        if is_last_rank():
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
+from .utils import (print_rank_0,
+                    is_last_rank,
+                    print_rank_last)
megatron/arguments.py  (+8, -7)

@@ -20,8 +20,7 @@ import os
 
 import torch
 
-def parse_args(extra_args_provider=None, defaults={},
-               ignore_unknown_args=False, validate=True):
+def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
     parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                      allow_abbrev=False)

@@ -53,14 +52,13 @@ def parse_args(extra_args_provider=None, defaults={},
     else:
         args = parser.parse_args()
 
-    if validate:
-        return validate_args(args, defaults)
+    # Args from environment
+    args.rank = int(os.getenv('RANK', '0'))
+    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+
     return args
 
 def validate_args(args, defaults={}):
-    # Distributed args.
-    args.rank = int(os.getenv('RANK', '0'))
-    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
 
     # Tensor model parallel size.
     args.tensor_model_parallel_size = min(
         args.tensor_model_parallel_size, args.world_size)

@@ -628,6 +626,9 @@ def _add_checkpointing_args(parser):
                        'can reduce startup time when definitely loading from a '
                        'checkpoint',
                        dest='perform_initialization')
+    group.add_argument('--use-checkpoint-args', action='store_true',
+                       help='Override any command line arguments with arguments '
+                       'from the checkpoint')
 
     return parser
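With this change, parsing and validation are decoupled: parse_args() now only reads the command line (plus RANK and WORLD_SIZE from the environment), while validate_args() carries the checks that previously ran inside parse_args() when validate=True. A minimal sketch of the resulting two-step flow; the patching step in the middle is a placeholder, not code from this commit:

    from megatron.arguments import parse_args, validate_args

    args = parse_args(ignore_unknown_args=True)   # parse only, no validation yet
    # ... a caller may adjust args here, e.g. from a saved checkpoint ...
    args = validate_args(args, defaults={})       # explicit, separate validation step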
megatron/checkpointing.py  (+16, -11)

@@ -22,11 +22,12 @@ import numpy as np
 
 import torch
 
-from megatron import (get_args,
-                      mpu,
-                      print_rank_0,
-                      update_num_microbatches,
-                      utils)
+from megatron import (mpu,
+                      update_num_microbatches)
+from .global_vars import get_args
+from .utils import (unwrap_model,
+                    print_rank_0)
 
 _CHECKPOINT_VERSION = None

@@ -207,7 +208,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)
 
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -386,8 +387,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
     return state_dict, release
 
 def load_args_from_checkpoint(args, load_arg='load'):
-    """Set any arguments that are not currently set from the checkpoint
-    specified in the arguments.
+    """Set required arguments from the checkpoint specified in the
+    arguments.
+
+    Will overwrite arguments that have a non-None default value, but
+    will leave any arguments that default to None as set.
 
     Returns the same args NameSpace with the new values added/updated.

@@ -406,6 +410,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
         return args
 
     if 'args' not in state_dict:
+        print('Checkpoint provided does not have arguments saved.')
         return args
 
     checkpoint_args = state_dict['args']

@@ -422,7 +427,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
         checkpoint_value = getattr(checkpoint_args, arg_name, None)
         if checkpoint_value is not None:
-            print(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
+            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
             setattr(args, arg_name, checkpoint_value)
 
     _set_arg('num_layers')

@@ -453,7 +458,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     args = get_args()
     load_dir = getattr(args, load_arg)
 
-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)
 
     state_dict, release = _load_base_checkpoint(load_dir, False)

@@ -574,7 +579,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     args = get_args()
 
-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)
 
     load_path = custom_load_path if custom_load_path is not None else args.load
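The reworked docstring describes the intended use from tools: given a parsed namespace whose --load points at a checkpoint directory, load_args_from_checkpoint() copies the arguments stored in the checkpoint's 'args' entry onto the namespace, or prints a notice and returns it unchanged if none were saved. A minimal sketch, with an illustrative checkpoint path:

    from megatron.arguments import parse_args
    from megatron.checkpointing import load_args_from_checkpoint

    # Run with: --load /path/to/checkpoints (path illustrative)
    args = parse_args(ignore_unknown_args=True)
    args = load_args_from_checkpoint(args)   # fills fields such as num_layers from the checkpoint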
megatron/global_vars.py  (+10, -14)

@@ -23,7 +23,6 @@ import torch
 
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
-from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator
 
 _GLOBAL_ARGS = None

@@ -86,16 +85,14 @@ def _set_signal_handler():
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
 
-def set_global_variables(extra_args_provider=None, args_defaults={},
-                         ignore_unknown_args=False, parse_args=True):
+def set_global_variables(args):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
-    if parse_args:
-        args = _parse_args(extra_args_provider=extra_args_provider,
-                           defaults=args_defaults,
-                           ignore_unknown_args=ignore_unknown_args)
-    else:
-        _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
-        args = get_args()
+
+    assert args is not None
+    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
+    set_args(args)
+
     _build_num_microbatches_calculator(args)
     if args.vocab_file:
         _ = _build_tokenizer(args)

@@ -117,10 +114,9 @@ def _parse_args(extra_args_provider=None, defaults={},
     """Parse entire arguments."""
     global _GLOBAL_ARGS
     _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
-    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
-                              defaults=defaults,
-                              ignore_unknown_args=ignore_unknown_args,
-                              validate=True)
+    _GLOBAL_ARGS = args
     return _GLOBAL_ARGS
megatron/initialize.py  (+13, -4)

@@ -28,6 +28,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
+from megatron.arguments import (parse_args, validate_args)
+from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
 from megatron.mpu import (set_tensor_model_parallel_rank,
                           set_tensor_model_parallel_world_size)

@@ -47,11 +49,18 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
-    # Parse args, build tokenizer, and set adlr-autoresume,
+    # Parse arguments
+    args = parse_args(extra_args_provider, ignore_unknown_args)
+
+    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
+        assert args.load is not None, '--use-checkpoints-args requires --load argument'
+        load_args_from_checkpoint(args)
+
+    validate_args(args, args_defaults)
+
+    # set global args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
-    set_global_variables(extra_args_provider=extra_args_provider,
-                         args_defaults=args_defaults,
-                         ignore_unknown_args=ignore_unknown_args)
+    set_global_variables(args)
 
     # torch.distributed initialization
     def finish_mpu_init():
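For training entry points, the same machinery is exposed through the new --use-checkpoint-args flag: initialize_megatron() parses the command line, optionally overrides the parsed arguments from the checkpoint named by --load, then validates and sets the globals. A hedged usage sketch; the launch command is illustrative and the flag requires --load:

    # Command-line form, e.g.: pretrain_gpt.py --use-checkpoint-args --load /path/to/checkpoints ...
    from megatron import initialize_megatron

    # Programmatic form, via the args_defaults check shown in the hunk above.
    initialize_megatron(args_defaults={'use_checkpoint_args': True})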
megatron/model/module.py  (+10, -10)

@@ -72,16 +72,6 @@ class MegatronModule(torch.nn.Module):
         if args.pipeline_model_parallel_size == 1:
             return
 
-        if not torch.distributed.is_initialized():
-            if not getattr(MegatronModule, "embedding_warning_printed", False):
-                print("WARNING! Distributed processes aren't initialized, so "
-                      "word embeddings in the last layer are not initialized. "
-                      "If you are just manipulating a model this is fine, but "
-                      "this needs to be handled manually. If you are training "
-                      "something is definitely wrong.")
-                MegatronModule.embedding_warning_printed = True
-            return
-
         # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different

@@ -112,6 +102,16 @@ class MegatronModule(torch.nn.Module):
                 self.pre_process:
             self.language_model.embedding.zero_parameters()
 
+        if not torch.distributed.is_initialized():
+            if not getattr(MegatronModule, "embedding_warning_printed", False):
+                print("WARNING! Distributed processes aren't initialized, so "
+                      "word embeddings in the last layer are not initialized. "
+                      "If you are just manipulating a model this is fine, but "
+                      "this needs to be handled manually. If you are training "
+                      "something is definitely wrong.")
+                MegatronModule.embedding_warning_printed = True
+            return
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_rank_in_embedding_group():
megatron/utils.py  (+19, -1)

@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared

@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
 
     return attention_mask, loss_mask, position_ids
 
+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+def is_last_rank():
+    return torch.distributed.get_rank() == (
+        torch.distributed.get_world_size() - 1)
+
+def print_rank_last(message):
+    """If distributed is initialized, print only on last rank."""
+    if torch.distributed.is_initialized():
+        if is_last_rank():
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
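Because megatron/__init__.py re-imports these helpers from megatron.utils, existing call sites keep working; only the definitions moved. For example, both import paths below resolve to the same functions after this commit:

    from megatron import print_rank_0, print_rank_last
    from megatron.utils import print_rank_0, print_rank_last, is_last_rank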
tools/checkpoint_loader_megatron.py  (+12, -12)

@@ -29,7 +29,7 @@ def _load_checkpoint(queue, args):
         from megatron.arguments import parse_args, validate_args
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
-        from megatron.model import ModelType
+        from megatron.model import ModelType, module
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")

@@ -51,9 +51,15 @@ def _load_checkpoint(queue, args):
                 '--load', args.load_dir
                 ]
 
-    margs = parse_args(validate=False)
+    margs = parse_args()
     margs = load_args_from_checkpoint(margs)
 
+    # Arguments do sanity checks on the world size, but we don't care,
+    # so trick it into thinking we are plenty of processes
+    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
+
+    margs = validate_args(margs)
+
     def check_for_arg(arg_name):
         if getattr(margs, arg_name, None) is None:
             print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")

@@ -71,13 +77,6 @@ def _load_checkpoint(queue, args):
     check_for_arg('tokenizer_type')
     check_for_arg('iteration')
     check_for_arg('bert_binary_head')
-
-    # Arguments do sanity checks on the world size, but we don't care,
-    # so trick it into thinking we are plenty of processes
-    os.environ["WORLD_SIZE"] = f'{margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size}'
-
-    margs = validate_args(margs)
-
     check_for_arg('params_dtype')

@@ -90,6 +89,9 @@ def _load_checkpoint(queue, args):
     else:
         raise Exception(f'unrecognized model type: {args.model_type}')
 
+    # supress warning about torch.distributed not being initialized
+    module.MegatronModule.embedding_warning_printed = True
+
     def get_models(count, dtype, pre_process, post_process):
         # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
         #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]

@@ -105,14 +107,12 @@ def _load_checkpoint(queue, args):
             models.append(model_[0])
         return models
 
-    set_args(margs)
-
     if margs.num_layers_per_virtual_pipeline_stage is not None:
         print("Model with an interleaved pipeline schedule are not yet supported.")
         queue.put("exit")
         exit(1)
 
-    set_global_variables(parse_args=False)
+    set_global_variables(margs)
     mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
     mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
tools/checkpoint_saver_megatron.py  (+6, -2)

@@ -28,6 +28,7 @@ def save_checkpoint(queue, args):
     sys.path.insert(0, args.megatron_path)
 
     try:
+        from megatron.arguments import (parse_args, validate_args)
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType

@@ -46,7 +47,6 @@ def save_checkpoint(queue, args):
     md = queue_get()
-
     if args.target_tensor_parallel_size is None:
         if hasattr(md, 'previous_tensor_parallel_size'):
             args.target_tensor_parallel_size = md.previous_tensor_parallel_size

@@ -102,7 +102,10 @@ def save_checkpoint(queue, args):
     if md.model_type == 'BERT' and not md.bert_binary_head:
         sys.argv.append('--bert-no-binary-head')
 
-    set_global_variables()
+    margs = parse_args()
+    validate_args(margs)
+    set_global_variables(margs)
 
     # margs = megatron args
     margs = get_args()

@@ -157,6 +160,7 @@ def save_checkpoint(queue, args):
     else:
         print("Original vocab size not specified, leaving embedding table as-is. "
               "If you've changed the tensor parallel size this could cause problems.")
+        margs.padded_vocab_size = orig_word_embed.shape[0]
     full_word_embed = orig_word_embed
 
     # Split into new tensor model parallel sizes