Unverified commit 7ee569b0, authored by Edenzzzz and committed by GitHub
Browse files

[hotfix] Fixed fused layernorm bug without apex (#5609)

* fixed fused layernorm bug without apex

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* same for flash attn

* remove flash attn check

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0d0a5820
...@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm): ...@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
# fall back to the normal fused layernorm is not built # fall back to the normal fused layernorm is not built
ApexFusedLayerNorm = FusedLayerNormWithHook ApexFusedLayerNorm = FusedLayerNormWithHook
else: else:
try:
ApexFusedLayerNorm = FusedLayerNormWithHook ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError:
warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
)
return module
layernorm = ( layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
......
...@@ -120,7 +120,15 @@ class ShardConfig: ...@@ -120,7 +120,15 @@ class ShardConfig:
Turn on all optimization. Turn on all optimization.
""" """
# you can add all the optimization flag here # you can add all the optimization flag here
self.enable_fused_normalization = True try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm # noqa
apex_avail = True
except ImportError:
apex_avail = False
warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
self.enable_fused_normalization = apex_avail
self.enable_flash_attention = True self.enable_flash_attention = True
self.enable_jit_fused = True self.enable_jit_fused = True
# This can cause non-in-place param sharding when used without ZeRO. # This can cause non-in-place param sharding when used without ZeRO.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment