Commit 07ebf714 authored by Jared Casper

Merge branch 'apex-fix2' into 'master'

Complete fix for APEX absence in NeMo

See merge request ADLR/megatron-lm!103
parents 7feb02c6 aa0ee72e
@@ -16,19 +16,11 @@
 """Transformer."""
 import math
 import torch
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-    # Try to use FusedLayerNorm from Apex - this will trigger an error.
-    _ = LayerNorm(8, eps=1e-5)
-except Exception as e:
-    print('WARNING: APEX is not available, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
-    from torch.nn import LayerNorm
 from megatron import get_args
 from megatron import mpu
+from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
@@ -32,6 +32,7 @@ from .initialize import get_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
+from .layers import LayerNorm
 from .layers import ColumnParallelLinear
 from .layers import ParallelEmbedding
 from .layers import RowParallelLinear
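Taken together, the first two hunks move the APEX decision into one place: transformer.py no longer carries its own try/except and instead imports LayerNorm from megatron.mpu, whose __init__ re-exports the class that .layers resolved at import time. A minimal self-contained sketch of that resolve-once, re-export pattern (the module split named in the comments is illustrative, not the exact Megatron-LM layout):

    # Resolve LayerNorm exactly once, preferring APEX's fused kernel.
    # (Sketch of the role played by megatron/mpu/layers.py in this merge.)
    try:
        from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
        _ = LayerNorm(8, eps=1e-5)  # probe: raises if the compiled kernel is missing
    except Exception:
        from torch.nn import LayerNorm  # pure-PyTorch fallback

    # megatron/mpu/__init__.py then does `from .layers import LayerNorm`, and
    # model code does `from megatron.mpu import LayerNorm`, so every consumer
    # gets the same class and the APEX decision is made only once.
    norm = LayerNorm(1024, eps=1e-5)
    print(type(norm))  # FusedLayerNorm if APEX is usable, else torch.nn.LayerNorm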
@@ -21,8 +21,12 @@
 import torch
 from torch._six import inf
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
+try:
+    from apex.multi_tensor_apply import multi_tensor_applier
+    import amp_C
+except Exception as e:
+    print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
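One thing to note about this hunk: when APEX is missing, multi_tensor_applier and amp_C are left undefined after the warning, so any code path that still reaches them will raise NameError. A hedged sketch of how a call site can guard the fused path (the HAVE_APEX flag and the grads_l2_norm helper are illustrative, not code from this merge request):

    import torch

    try:
        from apex.multi_tensor_apply import multi_tensor_applier
        import amp_C
        HAVE_APEX = True
    except Exception:
        HAVE_APEX = False

    def grads_l2_norm(params):
        """Global L2 norm of all gradients, fused via amp_C when available."""
        grads = [p.grad.detach() for p in params if p.grad is not None]
        if not grads:
            return torch.tensor(0.0)
        if HAVE_APEX and grads[0].is_cuda:
            overflow_buf = torch.zeros(1, dtype=torch.int, device=grads[0].device)
            # One fused kernel over the whole list; False = no per-tensor norms.
            norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, overflow_buf, [grads], False)
            return norm
        # Pure-PyTorch fallback: unfused, but functionally equivalent.
        return torch.norm(torch.stack([torch.norm(g) for g in grads]))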
@@ -25,7 +25,14 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
-from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+    # Probe for the compiled kernel: constructing a FusedLayerNorm raises
+    # if APEX's CUDA extension is missing, triggering the fallback below.
+    _ = LayerNorm(8, eps=1e-5)
+except Exception as e:
+    print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
+    from torch.nn import LayerNorm
 from .initialize import get_model_parallel_rank
 from .initialize import get_model_parallel_world_size
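The probe instantiation in this last hunk is the load-bearing part: importing FusedLayerNorm can succeed even when APEX was installed without its compiled fused_layer_norm_cuda extension, because that extension is only loaded when a module is constructed. Building a throwaway LayerNorm(8, eps=1e-5) forces the failure inside the try block, where the torch.nn fallback can catch it. A quick hedged check that the two classes are drop-in interchangeable for callers (shapes and tolerance are arbitrary):

    import torch
    from torch.nn import LayerNorm as TorchLayerNorm

    try:
        from apex.normalization.fused_layer_norm import FusedLayerNorm
        _ = FusedLayerNorm(8, eps=1e-5)  # fails here, not at import, if the kernel is missing
        HAVE_FUSED = True
    except Exception:
        HAVE_FUSED = False

    x = torch.randn(4, 8)
    ref = TorchLayerNorm(8, eps=1e-5)(x)
    if HAVE_FUSED and torch.cuda.is_available():
        fused = FusedLayerNorm(8, eps=1e-5).cuda()
        # Same constructor signature and forward contract as torch.nn.LayerNorm,
        # so swapping the class at import time is safe for downstream code.
        print(torch.allclose(fused(x.cuda()).cpu(), ref, atol=1e-5))
    else:
        print('Fused kernel unavailable; torch.nn.LayerNorm is used instead.')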