"tests/test_data/vscode:/vscode.git/clone" did not exist on "eca5a9f29e689a54d622bea5e3f5bca3ce2e4592"
Commit a2fdcdf0 authored by Sangkug Lym

persistent fused layer norm

fix the guard to fall back to the baseline fused layer norm kernel

Persistent ln: move the guard for supported hidden sizes to the layer norm module
parent f5345dfa
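In short, the change makes the persistent Apex kernel the default and falls back to the baseline fused kernel whenever any of three guards trips: the PyTorch version, the hidden size, or the new command-line flag. A condensed sketch of that decision as a standalone helper (the function name and call shape are illustrative, not code from this commit):

import torch

def use_persistent_layer_norm(hidden_size, no_persist_layer_norm, supported_sizes):
    # Illustrative helper mirroring the three guards introduced below.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 11):               # version guard added in parse_args
        return False
    if hidden_size not in supported_sizes:     # size guard moved into the layer norm module
        return False
    return not no_persist_layer_norm           # --no-persist-layer-norm opt-out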
@@ -257,6 +257,16 @@ def parse_args(extra_args_provider=None, defaults={},
         'currently distributed checkpoint activations only supported for ' \
         'non-interleaved pipeline parallelism'
 
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    # Persistent fused layer norm.
+    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
+        args.no_persist_layer_norm = True
+        if args.rank == 0:
+            print('Persistent fused layer norm kernel is supported from '
+                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
+                  'Defaulting to no_persist_layer_norm=True')
+
     _print_args(args)
     return args
@@ -486,6 +496,11 @@ def _add_training_args(parser):
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
                        'gradient computation of a column-linear layer.')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using the persistent fused layer norm kernel. '
+                       'This kernel supports only a set of hidden sizes. Please '
+                       'check persist_ln_hidden_sizes if your hidden '
+                       'size is supported.')
 
     return parser
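For readers less familiar with argparse, a minimal, self-contained example of how a store_true flag like this behaves; the parser below is illustrative and not the Megatron argument parser:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--no-persist-layer-norm', action='store_true',
                   help='Disable using the persistent fused layer norm kernel.')

# Dashes become underscores on the namespace; the attribute defaults to False
# and flips to True only when the flag is given on the command line.
print(parser.parse_args([]).no_persist_layer_norm)                           # False
print(parser.parse_args(['--no-persist-layer-norm']).no_persist_layer_norm)  # True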
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
 
+from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None

@@ -61,13 +63,22 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):
 
-  def __init__(self, normalized_shape, eps=1e-5):
+  def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
     super(MixedFusedLayerNorm, self).__init__()
 
     global fused_mix_prec_layer_norm_cuda
     fused_mix_prec_layer_norm_cuda = importlib.import_module(
       "fused_mix_prec_layer_norm_cuda")
 
+    # List of hidden sizes supported by the persistent layer norm kernel.
+    # If the hidden size is not supported, fall back to the non-persistent
+    # kernel.
+    persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+        5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+        24576, 25600, 30720, 32768, 40960, 49152, 65536]
+    if normalized_shape not in persist_ln_hidden_sizes:
+      no_persist_layer_norm = True
+
     if isinstance(normalized_shape, numbers.Integral):
       normalized_shape = (normalized_shape,)
     self.normalized_shape = torch.Size(normalized_shape)

@@ -75,6 +86,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
     self.weight = Parameter(torch.Tensor(*normalized_shape))
     self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()
+    self.no_persist_layer_norm = no_persist_layer_norm
 
   def reset_parameters(self):

@@ -85,6 +97,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
   def forward(self, input):
 
-    return FusedLayerNormAffineFunction.apply(
-      input, self.weight, self.bias, self.normalized_shape,self.eps)
+    if self.no_persist_layer_norm:
+      return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
+    else:
+      return FastLayerNormFN.apply(
+        input, self.weight, self.bias, self.eps)
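A hedged usage sketch of the updated module. It assumes a CUDA build of Apex that provides both the fused_mix_prec_layer_norm_cuda extension and the contrib persistent kernel, and an import path (megatron.model.fused_layer_norm) that is not shown in this diff:

import torch
from megatron.model.fused_layer_norm import MixedFusedLayerNorm  # assumed import path

# 4096 is in persist_ln_hidden_sizes, so no_persist_layer_norm=False keeps the
# persistent FastLayerNormFN path; an unsupported hidden size silently falls
# back to FusedLayerNormAffineFunction.
ln = MixedFusedLayerNorm(4096, eps=1e-5, no_persist_layer_norm=False).cuda().half()
x = torch.randn(8, 512, 4096, device='cuda', dtype=torch.float16)
y = ln(x)  # same shape as x, normalized over the last dimension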
...@@ -423,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -423,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the input data. # Layernorm on the input data.
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention. # Self attention.
self.self_attention = ParallelAttention( self.self_attention = ParallelAttention(
...@@ -438,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -438,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
if self.layer_type == LayerType.decoder: if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention( self.inter_attention = ParallelAttention(
...@@ -449,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -449,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output. # Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm( self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP # MLP
self.mlp = ParallelMLP(init_method, self.mlp = ParallelMLP(init_method,
...@@ -602,7 +605,8 @@ class ParallelTransformer(MegatronModule): ...@@ -602,7 +605,8 @@ class ParallelTransformer(MegatronModule):
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number): def _get_layer(self, layer_number):
return self.layers[layer_number] return self.layers[layer_number]