Commit 5ba627ec authored by Jared Casper

Merge branch 'optimize_initialization' into 'main'

Added GPU initialization and an option to avoid master values for initialization

See merge request ADLR/megatron-lm!105
parents 0a74547c 0403b808
@@ -18,6 +18,8 @@
import argparse
import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
@@ -62,6 +64,15 @@ def parse_args(extra_args_provider=None, defaults={},
if args.loss_scale is None:
args.dynamic_loss_scale = True
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
args.params_dtype = torch.half
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
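For readers skimming the hunk above, here is a minimal, standalone sketch (not part of the commit) of the parameter-dtype selection it adds, assuming an argparse namespace that already carries the fp16 and rank attributes:

# Illustrative only: mirrors the dtype logic added to parse_args above.
import argparse
import torch

args = argparse.Namespace(fp16=True, rank=0)

# Parameters stay in fp32 unless --fp16 was requested.
args.params_dtype = torch.float
if args.fp16:
    args.params_dtype = torch.half
if args.rank == 0:
    print('using {} for parameters ...'.format(args.params_dtype),
          flush=True)
# prints: using torch.float16 for parameters ...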
@@ -35,7 +35,6 @@ from .initialize import model_parallel_is_initialized
from .layers import LayerNorm
from .layers import ColumnParallelLinear
from .layers import ParallelEmbedding
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
@@ -31,7 +31,8 @@ try:
_ = LayerNorm(8, eps=1e-5)
except Exception as e:
print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_model_parallel_rank
@@ -44,11 +45,28 @@ from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
def _initialize_affine_weight(weight, output_size, input_size,
per_partition_size, partition_dim, init_method,
stride=1, return_master_weight=False):
_USE_CPU_INITIALIZATION = False
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim,
init_method, stride=1,
return_master_weight=False):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
@@ -56,7 +74,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
weight.model_parallel = True
weight.partition_dim = partition_dim
weight.stride = stride
weight.partition_stride = stride
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
@@ -68,9 +86,11 @@ def _initialize_affine_weight(weight, output_size, input_size,
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=weight.dtype,
dtype=torch.float,
requires_grad=False)
init_method(master_weight)
args = get_args()
master_weight = master_weight.to(dtype=args.params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
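The tail of _initialize_affine_weight_cpu is collapsed in the hunk above. The following self-contained sketch shows the master-weight-then-scatter pattern it implements; world_size, rank, and the helper name here are illustrative stand-ins for Megatron's runtime state, not the commit's code:

# Sketch of the CPU path: build the full ("master") weight in fp32 for
# reproducible values, cast it to the training dtype, then keep this
# rank's shard along the partitioned dimension.
import torch

def init_local_shard(output_size, input_size, partition_dim,
                     world_size, rank, params_dtype=torch.half,
                     init_method=torch.nn.init.xavier_normal_, stride=1):
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float, requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split along the partitioned dimension and keep this rank's pieces.
    per_partition_size = (output_size if partition_dim == 0
                          else input_size) // world_size
    per_partition_per_stride_size = per_partition_size // stride
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    my_weight_list = weight_list[rank::world_size]
    return torch.cat(my_weight_list, dim=partition_dim)

shard = init_local_shard(8, 4, partition_dim=0, world_size=2, rank=0)
print(shard.shape, shard.dtype)  # torch.Size([4, 4]) torch.float16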
@@ -119,13 +139,21 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
self.embedding_dim))
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
# Allocate weights and initialize.
args = get_args()
if _USE_CPU_INITIALIZATION:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype))
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_):
if self.model_parallel_size > 1:
@@ -150,55 +178,6 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the embedding dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim,
init_method=init.xavier_normal_,
keep_master_weight_for_test=False):
super(ParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set some defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Divide the weight matrix along the embedding dimension.
world_size = get_model_parallel_world_size()
self.embedding_dim_per_partition = divide(self.embedding_dim,
world_size)
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings,
self.embedding_dim_per_partition))
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
self.embedding_dim_per_partition, 1, init_method,
stride=1, return_master_weight=False)
def forward(self, input_):
input_parallel = copy_to_model_parallel_region(input_)
output_parallel = F.embedding(input_parallel, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output = gather_from_model_parallel_region(output_parallel)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
@@ -236,10 +215,32 @@ class ColumnParallelLinear(torch.nn.Module):
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
self.input_size))
# Initialize weight.
args = get_args()
if _USE_CPU_INITIALIZATION:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
if _USE_CPU_INITIALIZATION:
self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=args.params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
@@ -249,11 +250,7 @@ class ColumnParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
def forward(self, input_):
# Set up backprop all-reduce.
@@ -312,21 +309,37 @@ class RowParallelLinear(torch.nn.Module):
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size,
self.input_size_per_partition))
# Initialize weight.
args = get_args()
if _USE_CPU_INITIALIZATION:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias:
self.bias = Parameter(torch.Tensor(self.output_size))
if _USE_CPU_INITIALIZATION:
self.bias = Parameter(torch.empty(self.output_size,
dtype=args.params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=args.params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
def forward(self, input_):
# Set up backprop all-reduce.
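Taken together, the diff makes GPU-side initialization the default: the module-level _USE_CPU_INITIALIZATION constant is False, so unless it is edited the shards are allocated directly on the current device in args.params_dtype and no fp32 master weight is materialized. Below is a rough, illustrative sketch of that path (not the commit's code), assuming CUDA is available and Megatron's model-parallel state and RNG tracker have already been initialized:

# Sketch of the GPU path: allocate the local shard on the current device in
# the training dtype and initialize it inside the model-parallel RNG fork so
# each tensor-parallel rank draws its own reproducible stream.
import torch
from megatron import mpu

def gpu_init_shard(rows, cols, params_dtype=torch.half,
                   init_method=torch.nn.init.xavier_normal_,
                   partition_dim=0, stride=1):
    weight = torch.nn.Parameter(torch.empty(
        rows, cols, device=torch.cuda.current_device(), dtype=params_dtype))
    # Attributes the rest of Megatron uses to recognize partitioned parameters.
    weight.model_parallel = True
    weight.partition_dim = partition_dim
    weight.partition_stride = stride
    with mpu.get_cuda_rng_tracker().fork():
        init_method(weight)
    return weight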