Commit 84a5997a authored by Boris Fomitchev

Merge remote-tracking branch 'upstream/master' into onnx-erf

parents bb7c638f 07ebf714
@@ -39,6 +39,12 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+    # This is a temporary WAR to handle simple cases like pytest calling with the same args twice.
+    # Need to implement a clean factory init.
+    if mpu.model_parallel_is_initialized():
+        return
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
...
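The new guard makes initialize_megatron effectively idempotent, so a harness like pytest can invoke it twice with the same arguments without re-initializing model parallelism. A minimal sketch of the same pattern, using a hypothetical module-level flag rather than Megatron's mpu state:

# Sketch of an idempotent-initialization guard (hypothetical names, not part of
# this commit): a second call becomes a no-op once setup has already run.
_initialized = False

def initialize(args=None):
    global _initialized
    if _initialized:      # e.g. pytest calling the entry point a second time
        return
    # ... expensive one-time setup (process groups, globals, ...) would go here ...
    _initialized = True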
@@ -16,12 +16,11 @@
 """Transformer."""
 import math
 import torch
-from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
 from megatron import get_args
 from megatron import mpu
+from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
...
@@ -32,6 +32,7 @@ from .initialize import get_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
+from .layers import LayerNorm
 from .layers import ColumnParallelLinear
 from .layers import ParallelEmbedding
 from .layers import RowParallelLinear
...
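Re-exporting LayerNorm from megatron.mpu lets model code such as transformer.py stay unaware of whether Apex is installed; it simply imports whichever implementation layers.py resolved. A hedged usage sketch (the hidden size is illustrative; both Apex's FusedLayerNorm and torch.nn.LayerNorm accept this constructor form):

from megatron.mpu import LayerNorm   # FusedLayerNorm when Apex is present, torch.nn.LayerNorm otherwise

# Both implementations accept the (normalized_shape, eps) constructor used below;
# the hidden size is illustrative only.
norm = LayerNorm(1024, eps=1e-5)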
@@ -21,8 +21,12 @@
 import torch
 from torch._six import inf
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
+try:
+    from apex.multi_tensor_apply import multi_tensor_applier
+    import amp_C
+except Exception as e:
+    print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
...
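Since the Apex import is now optional, multi_tensor_applier and amp_C may be undefined at runtime. A sketch of how calling code could guard on availability; the HAVE_APEX flag and the scale_grads helper are illustrative and not part of this commit:

import torch

try:
    from apex.multi_tensor_apply import multi_tensor_applier
    import amp_C
    HAVE_APEX = True
except ImportError:
    HAVE_APEX = False

def scale_grads(grads, scale):
    """Scale a list of gradient tensors in place, using the fused kernel when available."""
    if HAVE_APEX and grads:
        overflow_buf = torch.cuda.IntTensor([0])   # no-op flag required by the fused kernel
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf,
                             [grads, grads], scale)
    else:
        for g in grads:                            # plain per-tensor fallback
            g.mul_(scale)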
@@ -25,7 +25,14 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
-from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+    # Try to use FusedLayerNorm from Apex - instantiating it here triggers an error
+    # at import time if Apex is not installed.
+    _ = LayerNorm(8, eps=1e-5)
+except Exception as e:
+    print('WARNING: APEX is not installed, using torch.nn.LayerNorm instead of apex.normalization.FusedLayerNorm!')
+    from torch.nn import LayerNorm
 from .initialize import get_model_parallel_rank
 from .initialize import get_model_parallel_world_size
...
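The dummy LayerNorm(8, eps=1e-5) instantiation makes a broken or missing Apex install fail inside the try block at import time, so the fallback to torch.nn.LayerNorm kicks in immediately rather than at first model construction. A quick, purely illustrative way to confirm which implementation was selected:

from megatron.mpu import LayerNorm

# Prints e.g. apex.normalization.fused_layer_norm.FusedLayerNorm
# or torch.nn.modules.normalization.LayerNorm, depending on what layers.py resolved.
print(f"{LayerNorm.__module__}.{LayerNorm.__qualname__}")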