"tests/test_data/vscode:/vscode.git/clone" did not exist on "9cb75e7ddadf447196d761e97a4bfc26502a5728"
Commit a2fdcdf0 authored by Sangkug Lym

persistent fused layer norm

fix the guard to fall back to the baseline fused layer norm kernel

Persistent ln: move the guard for supported hidden sizes to the layer norm module
parent f5345dfa
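In summary, the commit makes the persistent fused layer norm kernel (apex's FastLayerNormFN) the default and guards it twice: at argument-parsing time on the PyTorch version, and inside the layer norm module on the hidden size. The sketch below summarizes that selection logic; it is illustrative only, use_persistent_layer_norm is a hypothetical helper (not a function added by this commit), and PERSIST_LN_HIDDEN_SIZES mirrors the persist_ln_hidden_sizes list added in the diff.

import torch

# Hidden sizes supported by the persistent kernel, per the list added to
# MixedFusedLayerNorm below.
PERSIST_LN_HIDDEN_SIZES = {1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120,
                           6144, 8192, 10240, 12288, 12800, 15360, 16384,
                           18432, 20480, 24576, 25600, 30720, 32768, 40960,
                           49152, 65536}

def use_persistent_layer_norm(hidden_size, no_persist_layer_norm=False):
    """Return True only when the persistent fused kernel may be used."""
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    if major < 1 or (major == 1 and minor < 11):
        # The persistent kernel needs pytorch >= 1.11; use the baseline kernel.
        return False
    if hidden_size not in PERSIST_LN_HIDDEN_SIZES:
        # Unsupported hidden size; fall back to the baseline fused kernel.
        return False
    return not no_persist_layer_norm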
@@ -257,6 +257,16 @@ def parse_args(extra_args_provider=None, defaults={},
         'currently distributed checkpoint activations only supported for ' \
         'non-interleaved pipeline parallelism'
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    # Persistent fused layer norm.
+    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
+        args.no_persist_layer_norm = True
+        if args.rank == 0:
+            print('Persistent fused layer norm kernel is supported from '
+                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
+                  'Defaulting to no_persist_layer_norm=True')
     _print_args(args)
     return args
@@ -486,6 +496,11 @@ def _add_training_args(parser):
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
                        'gradient computation of a column-linear layer.')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using persistent fused layer norm kernel. '
+                       'This kernel supports only a set of hidden sizes. Please '
+                       'check persist_ln_hidden_sizes if your hidden '
+                       'size is supported.')
     return parser
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -61,13 +63,22 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5):
+    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
         super(MixedFusedLayerNorm, self).__init__()
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
+        # List of hidden sizes supported in the persistent layer norm kernel.
+        # If the hidden size is not supported, fall back to the non-persistent
+        # kernel.
+        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+            24576, 25600, 30720, 32768, 40960, 49152, 65536]
+        if normalized_shape not in persist_ln_hidden_sizes:
+            no_persist_layer_norm = True
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
@@ -75,6 +86,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = Parameter(torch.Tensor(*normalized_shape))
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        self.no_persist_layer_norm = no_persist_layer_norm
     def reset_parameters(self):
@@ -85,6 +97,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def forward(self, input):
-        return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
+        if self.no_persist_layer_norm:
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight, self.bias, self.normalized_shape, self.eps)
+        else:
+            return FastLayerNormFN.apply(
+                input, self.weight, self.bias, self.eps)
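A hedged usage sketch of the updated module follows. It assumes a CUDA machine with this patch applied and with Megatron-LM and apex importable; the module path megatron.model.fused_layer_norm is an assumption, not something stated in the diff.

import torch
from megatron.model.fused_layer_norm import MixedFusedLayerNorm  # path assumed

# 2304 is in persist_ln_hidden_sizes, so with no_persist_layer_norm=False the
# forward pass takes the persistent FastLayerNormFN path; an unsupported size
# would silently flip no_persist_layer_norm back to True in __init__.
ln = MixedFusedLayerNorm(2304, eps=1e-5, no_persist_layer_norm=False).cuda().half()
x = torch.randn(8, 512, 2304, device='cuda', dtype=torch.half)
y = ln(x)  # same shape as x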
@@ -423,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -438,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -449,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the attention output.
         self.post_inter_attention_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
         # MLP
         self.mlp = ParallelMLP(init_method,
@@ -602,7 +605,8 @@ class ParallelTransformer(MegatronModule):
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
     def _get_layer(self, layer_number):
         return self.layers[layer_number]