Commit 07ebf714 authored by Jared Casper

Merge branch 'apex-fix2' into 'master'

Complete fix for APEX absence in NeMo

See merge request ADLR/megatron-lm!103
parents 7feb02c6 aa0ee72e
@@ -16,19 +16,11 @@
 """Transformer."""
 import math
 import torch
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-    # Try to use FusedLayerNorm from Apex - this will trigger an error.
-    _ = LayerNorm(8, eps=1e-5)
-except Exception as e:
-    print('WARNING: APEX is not available, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
-    from torch.nn import LayerNorm
 from megatron import get_args
 from megatron import mpu
+from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
...
@@ -32,6 +32,7 @@ from .initialize import get_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
+from .layers import LayerNorm
 from .layers import ColumnParallelLinear
 from .layers import ParallelEmbedding
 from .layers import RowParallelLinear
...
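Taken together, the two hunks above move the Apex-or-PyTorch decision out of the model code: the mpu package re-exports LayerNorm, so callers import it from one place and never touch Apex directly. A minimal sketch of the consumer side (the hidden size is an illustrative value, not from this MR):

from megatron.mpu import LayerNorm

# Whether this resolves to Apex's FusedLayerNorm or torch.nn.LayerNorm is
# decided once, inside the mpu package, when its fallback import runs.
hidden_size = 1024  # illustrative value
input_layernorm = LayerNorm(hidden_size, eps=1e-5)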
@@ -21,8 +21,12 @@
 import torch
 from torch._six import inf
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
+try:
+    from apex.multi_tensor_apply import multi_tensor_applier
+    import amp_C
+except Exception as e:
+    print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
...
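With the import now optional, call sites in this file have to cope with multi_tensor_applier and amp_C being undefined when Apex is absent. The hunk only changes the import, so the sketch below is an assumption about how a caller could guard the fused path; the HAVE_APEX flag, the scale_grads helper, and the fallback loop are illustrative, not code from this MR:

import torch

try:
    from apex.multi_tensor_apply import multi_tensor_applier
    import amp_C
    HAVE_APEX = True
except Exception:
    HAVE_APEX = False

def scale_grads(grads, clip_coef):
    """Scale a list of gradient tensors in place by clip_coef."""
    if HAVE_APEX:
        # Fused path: one multi-tensor kernel scales every (CUDA) gradient.
        overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [grads, grads],
                             clip_coef)
    else:
        # Plain PyTorch fallback: scale each gradient tensor individually.
        for g in grads:
            g.mul_(clip_coef)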
@@ -25,7 +25,14 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
-from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+    # Try to use FusedLayerNorm from Apex - this will trigger an error.
+    _ = LayerNorm(8, eps=1e-5)
+except Exception as e:
+    print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
+    from torch.nn import LayerNorm
 from .initialize import get_model_parallel_rank
 from .initialize import get_model_parallel_world_size
...
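The _ = LayerNorm(8, eps=1e-5) probe matters because Apex can be importable while its fused CUDA extension is not built; constructing a layer forces that extension to load, so a partial install also falls back to PyTorch. For reference, the same logic written as a small helper (a hypothetical refactoring, not code from this MR):

def _resolve_layer_norm():
    """Return Apex's FusedLayerNorm when it is usable, torch.nn.LayerNorm otherwise."""
    try:
        from apex.normalization.fused_layer_norm import FusedLayerNorm
        # Constructing a small layer loads the fused CUDA extension; a partial
        # Apex install (Python package present, extension missing) raises here.
        FusedLayerNorm(8, eps=1e-5)
        return FusedLayerNorm
    except Exception:
        print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead.')
        from torch.nn import LayerNorm
        return LayerNorm

LayerNorm = _resolve_layer_norm()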