"tests/vscode:/vscode.git/clone" did not exist on "b26da25297c42c81ab98bb906979e1269ca05f37"
Commit 42d2be06 authored by Mohammad Shoeybi

Merge branch 'onnx-safe' into 'main'

--onnx-safe made overridable; --lazy-mpu-init now acts as --use-cpu-initialization

See merge request ADLR/megatron-lm!117
parents 3d1cbecf b6b48a32
@@ -169,7 +169,7 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
-    group.add_argument('--onnx-safe', action='store_true',
+    group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with Torch ONNX exporter')
     return parser
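Why this hunk makes `--onnx-safe` overridable (a hedged reading of the change, not text from the merge request): with `action='store_true'` the parsed value is always `False` when the flag is absent, so later code cannot tell "not given" apart from "explicitly off"; with `type=bool, required=False` and no default the attribute comes back as `None`, and a caller such as `initialize_megatron(args_defaults=...)` (whose signature appears in a hunk below) can fill it in. A minimal standalone sketch; the `defaults` dict and override loop are illustrative, not repository code.

```python
# Minimal sketch: why `type=bool, required=False` is overridable while
# `action='store_true'` is not. Not Megatron code; only the argument
# definition mirrors the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--onnx-safe', type=bool, required=False,
                    help='Use workarounds for known problems with Torch ONNX exporter')

args = parser.parse_args([])      # flag not given on the command line
assert args.onnx_safe is None     # store_true would have forced this to False

# A caller-supplied default (e.g. via something like args_defaults) can now
# be applied only when the user said nothing:
defaults = {'onnx_safe': True}    # hypothetical defaults dict
for key, value in defaults.items():
    if getattr(args, key, None) is None:
        setattr(args, key, value)
assert args.onnx_safe is True

# Caveat of type=bool on the command line: any non-empty string parses as
# True, so `--onnx-safe False` also yields True (bool('False') is True).
```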
@@ -335,8 +335,11 @@ def _add_distributed_args(parser):
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
                        help='If set to True, initialize_megatron() skips DDP initialization'
-                       ' and returns function to complete it instead'
+                       ' and returns function to complete it instead.'
+                       'Also turns on --use-cpu-initialization flag.'
                        'This is for external DDP manager.' )
+    group.add_argument('--use-cpu-initialization', action='store_true',
+                       help='If set, affine parallel weights initialization uses CPU' )
     return parser
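The `--lazy-mpu-init` help text above spells out a calling pattern for an external DDP manager; the sketch below illustrates it under stated assumptions. Only `initialize_megatron()`, its `args_defaults` keyword (visible in the next hunk), `lazy_mpu_init`, and `use_cpu_initialization` come from this diff; the import path, the idea of switching `lazy_mpu_init` on through `args_defaults` rather than the command line, and the external-framework stub are assumptions.

```python
# Hedged sketch of the external-DDP-manager flow described by the help
# text above; the external framework here is a stand-in stub.
from megatron.initialize import initialize_megatron   # assumed import path

def external_ddp_setup():
    """Placeholder for the external framework creating its process groups."""
    pass

# With lazy_mpu_init set, initialize_megatron() skips DDP initialization,
# returns a function that completes it later, and (after this merge) also
# turns on use_cpu_initialization so weights are built on the CPU first.
finish_mpu_init = initialize_megatron(args_defaults={'lazy_mpu_init': True})

external_ddp_setup()   # external framework sets up its own distributed state
finish_mpu_init()      # then complete Megatron's deferred MPU setup
```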
...
@@ -62,6 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     args = get_args()
     if args.lazy_mpu_init:
+        args.use_cpu_initialization=True
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_model_parallel_world_size(args.model_parallel_size)
...
@@ -47,10 +47,6 @@ from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args
-_USE_CPU_INITIALIZATION = False
 def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
     """Initialize affine weight for model parallel on GPU."""
@@ -141,7 +137,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Allocate weights and initialize.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(
                 self.num_embeddings_per_partition, self.embedding_dim,
                 dtype=args.params_dtype))
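The remaining hunks apply the same substitution inside the parallel layers, so one simplified sketch of the branch being selected is enough. The CPU branch mirrors the `torch.empty(...)` allocations visible in the hunks; the GPU branch is an assumption based on the `_initialize_affine_weight_gpu` call sites (its `else:` body is not shown in this diff), so treat it as illustrative only.

```python
# Simplified, self-contained sketch of the allocation pattern the
# use_cpu_initialization flag toggles in VocabParallelEmbedding,
# ColumnParallelLinear and RowParallelLinear. Not repository code.
import torch
from torch.nn.parameter import Parameter

def make_partition_weight(rows, cols, init_method,
                          use_cpu_initialization, params_dtype=torch.float32):
    """Hypothetical helper mirroring the if/else shape of the diff."""
    if use_cpu_initialization:
        # CPU path (the branch shown in the hunks): allocate on the host,
        # initialize there, move/cast later with the rest of the model.
        weight = Parameter(torch.empty(rows, cols, dtype=params_dtype))
        init_method(weight)
    else:
        # GPU path (assumed else-branch): allocate on the current device and
        # initialize in place, as the _initialize_affine_weight_gpu call
        # sites in the later hunks suggest.
        weight = Parameter(torch.empty(rows, cols,
                                       device=torch.cuda.current_device(),
                                       dtype=params_dtype))
        init_method(weight)
    return weight

# Example (CPU path, no GPU required):
#   w = make_partition_weight(1024, 256, torch.nn.init.xavier_normal_, True)
```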
@@ -217,7 +213,7 @@ class ColumnParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                 self.input_size,
                                                 dtype=args.params_dtype))
@@ -233,7 +229,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                           partition_dim=0, stride=stride)
         if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(
                     self.output_size_per_partition, dtype=args.params_dtype))
             else:
@@ -311,7 +307,7 @@ class RowParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(self.output_size,
                                                 self.input_size_per_partition,
                                                 dtype=args.params_dtype))
@@ -326,7 +322,7 @@ class RowParallelLinear(torch.nn.Module):
         _initialize_affine_weight_gpu(self.weight, init_method,
                                       partition_dim=1, stride=stride)
         if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(self.output_size,
                                                   dtype=args.params_dtype))
             else:
...