"git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "c3a3c6780c28b703565d02169b2e1e5a52214eb6"
Commit 03a2903f authored by Sangkug Lym's avatar Sangkug Lym
Browse files

add a kernel import guard for persistent layer norm

parent 4a3213f1
@@ -23,7 +23,11 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
-from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -76,7 +80,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
         persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
             5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
             24576, 25600, 30720, 32768, 40960, 49152, 65536]
-        if normalized_shape not in persist_ln_hidden_sizes:
+        if normalized_shape not in persist_ln_hidden_sizes or \
+                not HAVE_PERSIST_LAYER_NORM:
             no_persist_layer_norm = True
         if isinstance(normalized_shape, numbers.Integral):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment