"vscode:/vscode.git/clone" did not exist on "2bce175d156b5c5c1be925cb57fe33215675fafd"
Commit d6380fda authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'slym/persist-ln-fix' into 'main'

add a kernel import guard for persistent layer norm

See merge request ADLR/megatron-lm!363
parents 4a3213f1 03a2903f
...@@ -23,7 +23,11 @@ from torch.nn.parameter import Parameter ...@@ -23,7 +23,11 @@ from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import importlib import importlib
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
except:
HAVE_PERSIST_LAYER_NORM = False
global fused_mix_prec_layer_norm_cuda global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None fused_mix_prec_layer_norm_cuda = None
...@@ -76,7 +80,8 @@ class MixedFusedLayerNorm(torch.nn.Module): ...@@ -76,7 +80,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536] 24576, 25600, 30720, 32768, 40960, 49152, 65536]
if normalized_shape not in persist_ln_hidden_sizes: if normalized_shape not in persist_ln_hidden_sizes or \
not HAVE_PERSIST_LAYER_NORM:
no_persist_layer_norm = True no_persist_layer_norm = True
if isinstance(normalized_shape, numbers.Integral): if isinstance(normalized_shape, numbers.Integral):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment