Commit 0403b808 authored by mohammad

added gpu initialization and option to avoid master values

parent 61f50c45
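Reviewer note: the diff below switches each model-parallel layer from always building an fp32 master weight on CPU and scattering it, to allocating the weight partition directly on the GPU in args.params_dtype and initializing it in place (the CPU master-weight path stays available behind the _USE_CPU_INITIALIZATION flag). A minimal, self-contained sketch of that pattern follows; the function name build_partitioned_weight and the standalone-function form are illustrative only and not part of the commit, and the real GPU path additionally wraps init_method in get_cuda_rng_tracker().fork() so each partition draws from its own RNG stream.

# Minimal sketch of the initialization pattern introduced by this commit.
# Only model_parallel / partition_dim / partition_stride mirror the diff;
# everything else here is illustrative.
import torch
from torch.nn import Parameter, init


def build_partitioned_weight(rows, cols, params_dtype=torch.float,
                             init_method=init.xavier_normal_,
                             use_cpu_init=False,
                             partition_dim=0, stride=1):
    """Allocate and initialize one model-parallel partition of a weight."""
    if use_cpu_init:
        # CPU path: allocate in params_dtype; the commit's
        # _initialize_affine_weight_cpu builds a full fp32 master weight,
        # casts it to params_dtype, and scatters a slice into the partition.
        weight = Parameter(torch.empty(rows, cols, dtype=params_dtype))
        init_method(weight)  # stand-in for the master-weight scatter
    else:
        # GPU path: allocate directly on the current device and initialize
        # in place, so no full-size master weight is ever materialized.
        device = (torch.cuda.current_device()
                  if torch.cuda.is_available() else 'cpu')
        weight = Parameter(torch.empty(rows, cols, device=device,
                                       dtype=params_dtype))
        init_method(weight)
    # Attributes downstream code uses to recognize partitioned parameters.
    weight.model_parallel = True
    weight.partition_dim = partition_dim
    weight.partition_stride = stride
    return weight


if __name__ == '__main__':
    w = build_partitioned_weight(8, 16)
    print(w.shape, w.dtype, w.partition_dim)

In the diff itself this pattern is applied inside VocabParallelEmbedding, ColumnParallelLinear, and RowParallelLinear via _initialize_affine_weight_gpu / _initialize_affine_weight_cpu, with args.params_dtype chosen in parse_args (torch.half when --fp16 is set, torch.float otherwise).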
@@ -18,6 +18,8 @@
 import argparse
 import os
+import torch
+
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -62,6 +64,15 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.loss_scale is None:
         args.dynamic_loss_scale = True
 
+    # Parameters dtype.
+    args.params_dtype = torch.float
+    if args.fp16:
+        args.params_dtype = torch.half
+    if args.rank == 0:
+        print('using {} for parameters ...'.format(args.params_dtype),
+              flush=True)
+
+
     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
...
@@ -35,7 +35,6 @@ from .initialize import model_parallel_is_initialized
 from .layers import LayerNorm
 from .layers import ColumnParallelLinear
-from .layers import ParallelEmbedding
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
...
@@ -31,7 +31,8 @@ try:
     _ = LayerNorm(8, eps=1e-5)
 except Exception as e:
-    print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
+    print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
+          'instead of apex.normalization.FusedLayerNorm!')
     from torch.nn import LayerNorm
 
 from .initialize import get_model_parallel_rank
@@ -44,11 +45,28 @@ from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
+from megatron import get_args
+
+
+_USE_CPU_INITIALIZATION = False
 
-def _initialize_affine_weight(weight, output_size, input_size,
-                              per_partition_size, partition_dim, init_method,
-                              stride=1, return_master_weight=False):
+
+def _initialize_affine_weight_gpu(weight, init_method,
+                                  partition_dim, stride=1):
+    """Initialize affine weight for model parallel on GPU."""
+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.partition_stride = stride
+    with get_cuda_rng_tracker().fork():
+        init_method(weight)
+
+
+def _initialize_affine_weight_cpu(weight, output_size, input_size,
+                                  per_partition_size, partition_dim,
+                                  init_method, stride=1,
+                                  return_master_weight=False):
     """Initialize affine weight for model parallel.
 
     Build the master weight on all processes and scatter
@@ -56,7 +74,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     weight.model_parallel = True
     weight.partition_dim = partition_dim
-    weight.stride = stride
+    weight.partition_stride = stride
 
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
@@ -68,9 +86,11 @@ def _initialize_affine_weight(weight, output_size, input_size,
     # Initialize master weight
     master_weight = torch.empty(output_size, input_size,
-                                dtype=weight.dtype,
+                                dtype=torch.float,
                                 requires_grad=False)
     init_method(master_weight)
+    args = get_args()
+    master_weight = master_weight.to(dtype=args.params_dtype)
 
     # Split and copy
     per_partition_per_stride_size = divide(per_partition_size, stride)
@@ -119,13 +139,21 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
 
-        # Allocate weights.
-        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
-                                             self.embedding_dim))
-        # And initialize.
-        _initialize_affine_weight(
-            self.weight, self.num_embeddings, self.embedding_dim,
-            self.num_embeddings_per_partition, 0, init_method)
+        # Allocate weights and initialize.
+        args = get_args()
+        if _USE_CPU_INITIALIZATION:
+            self.weight = Parameter(torch.empty(
+                self.num_embeddings_per_partition, self.embedding_dim,
+                dtype=args.params_dtype))
+            _initialize_affine_weight_cpu(
+                self.weight, self.num_embeddings, self.embedding_dim,
+                self.num_embeddings_per_partition, 0, init_method)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.num_embeddings_per_partition, self.embedding_dim,
+                device=torch.cuda.current_device(), dtype=args.params_dtype))
+            _initialize_affine_weight_gpu(self.weight, init_method,
+                                          partition_dim=0, stride=1)
 
     def forward(self, input_):
         if self.model_parallel_size > 1:
@@ -150,55 +178,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output
 
 
-class ParallelEmbedding(torch.nn.Module):
-    """Embedding parallelized in the embedding dimension.
-
-    This is mainly adapted from torch.nn.Embedding and all the default
-    values are kept.
-    Arguments:
-        num_embeddings: vocabulary size.
-        embedding_dim: size of hidden state.
-        init_method: method to initialize weights.
-    """
-
-    def __init__(self, num_embeddings, embedding_dim,
-                 init_method=init.xavier_normal_,
-                 keep_master_weight_for_test=False):
-        super(ParallelEmbedding, self).__init__()
-        # Keep the input dimensions.
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        # Set some detauls for compatibility.
-        self.padding_idx = None
-        self.max_norm = None
-        self.norm_type = 2.
-        self.scale_grad_by_freq = False
-        self.sparse = False
-        self._weight = None
-        # Divide the weight matrix along the embedding dimension.
-        world_size = get_model_parallel_world_size()
-        self.embedding_dim_per_partition = divide(self.embedding_dim,
-                                                  world_size)
-
-        # Allocate weights.
-        self.weight = Parameter(torch.Tensor(self.num_embeddings,
-                                             self.embedding_dim_per_partition))
-        # And initialize.
-        _initialize_affine_weight(
-            self.weight, self.num_embeddings, self.embedding_dim,
-            self.embedding_dim_per_partition, 1, init_method,
-            stride=1, return_master_weight=False)
-
-    def forward(self, input_):
-        input_parallel = copy_to_model_parallel_region(input_)
-        output_parallel = F.embedding(input_parallel, self.weight,
-                                      self.padding_idx, self.max_norm,
-                                      self.norm_type, self.scale_grad_by_freq,
-                                      self.sparse)
-        output = gather_from_model_parallel_region(output_parallel)
-        return output
-
-
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
@@ -236,10 +215,32 @@ class ColumnParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
-                                             self.input_size))
+        # Initialize weight.
+        args = get_args()
+        if _USE_CPU_INITIALIZATION:
+            self.weight = Parameter(torch.empty(self.output_size_per_partition,
+                                                self.input_size,
+                                                dtype=args.params_dtype))
+            self.master_weight = _initialize_affine_weight_cpu(
+                self.weight, self.output_size, self.input_size,
+                self.output_size_per_partition, 0, init_method,
+                stride=stride, return_master_weight=keep_master_weight_for_test)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.output_size_per_partition, self.input_size,
+                device=torch.cuda.current_device(), dtype=args.params_dtype))
+            _initialize_affine_weight_gpu(self.weight, init_method,
+                                          partition_dim=0, stride=stride)
+
         if bias:
-            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
+            if _USE_CPU_INITIALIZATION:
+                self.bias = Parameter(torch.empty(
+                    self.output_size_per_partition, dtype=args.params_dtype))
+            else:
+                self.bias = Parameter(torch.empty(
+                    self.output_size_per_partition,
+                    device=torch.cuda.current_device(),
+                    dtype=args.params_dtype))
             self.bias.model_parallel = True
             self.bias.partition_dim = 0
             self.bias.stride = stride
@@ -249,11 +250,7 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
-        # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
-            self.weight, self.output_size, self.input_size,
-            self.output_size_per_partition, 0, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -312,21 +309,37 @@ class RowParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        self.weight = Parameter(torch.Tensor(self.output_size,
-                                             self.input_size_per_partition))
+        # Initialize weight.
+        args = get_args()
+        if _USE_CPU_INITIALIZATION:
+            self.weight = Parameter(torch.empty(self.output_size,
+                                                self.input_size_per_partition,
+                                                dtype=args.params_dtype))
+            self.master_weight = _initialize_affine_weight_cpu(
+                self.weight, self.output_size, self.input_size,
+                self.input_size_per_partition, 1, init_method,
+                stride=stride, return_master_weight=keep_master_weight_for_test)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.output_size, self.input_size_per_partition,
+                device=torch.cuda.current_device(), dtype=args.params_dtype))
+            _initialize_affine_weight_gpu(self.weight, init_method,
+                                          partition_dim=1, stride=stride)
         if bias:
-            self.bias = Parameter(torch.Tensor(self.output_size))
+            if _USE_CPU_INITIALIZATION:
+                self.bias = Parameter(torch.empty(self.output_size,
+                                                  dtype=args.params_dtype))
+            else:
+                self.bias = Parameter(torch.empty(
+                    self.output_size, device=torch.cuda.current_device(),
+                    dtype=args.params_dtype))
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
 
-        # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
-            self.weight, self.output_size, self.input_size,
-            self.input_size_per_partition, 1, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
...