Commit 62632d39 authored by mohammad, committed by Deepak Narayanan


Add the option for fp32 residual connection (fp32 residual connection machinery still needs to be added)
parent 9b174da8
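The commit message notes that only the option is added here; the fp32 residual machinery itself lands later. For orientation, a minimal illustrative sketch of the idea behind the flag (not part of this commit; names are made up): the residual stream is kept in fp32 while an fp16 layer does the heavy compute, and the result is accumulated back in fp32.

import torch

# Illustrative sketch only, not part of this commit.
def residual_block_fp32(hidden_states_fp32, layer_fp16):
    layer_out = layer_fp16(hidden_states_fp32.half())  # compute in fp16
    return hidden_states_fp32 + layer_out.float()       # accumulate in fp32

if torch.cuda.is_available():
    layer = torch.nn.Linear(16, 16).cuda().half()
    stream = torch.randn(4, 16, device='cuda', dtype=torch.float32)
    print(residual_block_fp32(stream, layer).dtype)      # torch.float32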
@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={},
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16, \
'residual connection in fp32 is only supported in fp16 mode.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
@@ -435,6 +438,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--fp32-residual-connection', action='store_true',
help='Move residual connections to fp32.')
group.add_argument('--apply-query-key-layer-scaling', action='store_true',
help='Scale Q * K^T by 1 / layer-number. If this flag '
'is set, then it will automatically set '
......
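A standalone sketch of the argument wiring above (argument and attribute names mirror the diff, but this is not the Megatron parser itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='Run model in fp16 mode.')
parser.add_argument('--fp32-residual-connection', action='store_true',
                    help='Move residual connections to fp32.')
args = parser.parse_args(['--fp16', '--fp32-residual-connection'])

# Same consistency check as in parse_args(): the fp32 residual connection
# only makes sense when the rest of the model runs in fp16.
if args.fp32_residual_connection:
    assert args.fp16, \
        'residual connection in fp32 is only supported in fp16 mode.'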
@@ -13,9 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
_LAYER_NORM = None
def import_layernorm(fp32_residual_connection):
global _LAYER_NORM
if not _LAYER_NORM:
if fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
_LAYER_NORM = LayerNorm
return _LAYER_NORM
from .distributed import *
from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model
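A usage sketch for the new import_layernorm helper (assuming Apex and megatron are importable). The first call caches its choice in _LAYER_NORM, so every later caller receives the same LayerNorm class no matter which flag value it passes:

from megatron.model import import_layernorm

LayerNorm = import_layernorm(fp32_residual_connection=True)
final_norm = LayerNorm(1024, eps=1e-5)

# A second call with a different flag still returns the cached class.
assert import_layernorm(fp32_residual_connection=False) is LayerNorm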
@@ -21,7 +21,7 @@ from megatron import get_args
from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model import import_layernorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
......
@@ -21,9 +21,9 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule):
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule):
super(ParallelTransformer, self).__init__()
args = get_args()
self.fp32_residual_connection = args.fp32_residual_connection
# Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
@@ -520,6 +523,7 @@
if mpu.is_pipeline_last_stage():
# Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
@@ -564,6 +568,9 @@
if mpu.is_pipeline_first_stage():
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.checkpoint_activations:
......
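An illustrative, standalone version of the first-stage data-format change above (outside Megatron; tensor sizes are arbitrary): the embedding output arrives as [b s h], the transformer consumes [s b h], and with --fp32-residual-connection the residual stream is additionally cast to fp32.

import torch

def to_transformer_format(hidden_states, fp32_residual_connection):
    # [b s h] --> [s b h], matching the transpose in the diff.
    hidden_states = hidden_states.transpose(0, 1).contiguous()
    if fp32_residual_connection:
        # Keep the residual stream in fp32 from the very first stage.
        hidden_states = hidden_states.float()
    return hidden_states

x = torch.randn(2, 8, 16).half()   # [b=2, s=8, h=16] in fp16
y = to_transformer_format(x, fp32_residual_connection=True)
print(y.shape, y.dtype)             # torch.Size([8, 2, 16]) torch.float32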
@@ -19,8 +19,8 @@ import math
import torch
from .transformer import LayerNorm
from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
......
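A simplified sketch of the grouping that get_params_for_weight_decay_optimization performs (not the full Megatron helper): LayerNorm parameters and biases get weight_decay=0.0, everything else keeps the default weight decay. Looking the class up via import_layernorm() ensures the isinstance() check matches whichever LayerNorm implementation the model was built with.

import torch

def split_weight_decay_params(module, LayerNorm=torch.nn.LayerNorm):
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            # All layernorm parameters are exempt from weight decay.
            no_weight_decay_params['params'].extend(module_.parameters())
        else:
            for name, p in module_.named_parameters(recurse=False):
                group = (no_weight_decay_params if name == 'bias'
                         else weight_decay_params)
                group['params'].append(p)
    return weight_decay_params, no_weight_decay_params

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
wd, no_wd = split_weight_decay_params(model)
print(len(wd['params']), len(no_wd['params']))  # 1 3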
@@ -41,7 +41,6 @@ from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_mod
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LayerNorm
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
......
@@ -25,16 +25,6 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
# Try to use FusedLayerNorm from Apex - this will trigger an error if Apex is not installed or not built correctly.
_ = LayerNorm(8, eps=1e-5)
except Exception as e:
print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
......
@@ -333,16 +333,19 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=args.params_dtype)
dtype=dtype)
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=args.params_dtype)
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
......
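A sketch of the receive-buffer allocation above (illustrative only; the ring_exchange call in this branch is a custom primitive and is not reproduced here, and `args` stands in for Megatron's parsed arguments): pipeline activations normally travel in params_dtype (fp16), but once the residual stream is kept in fp32 the buffers that receive it must be fp32 as well.

import torch
from types import SimpleNamespace

def alloc_recv_buffer(args, recv):
    if not recv:
        return None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    # fp32 residual connections force the communication buffers to fp32.
    dtype = torch.float if args.fp32_residual_connection else args.params_dtype
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    return torch.empty(tensor_shape, requires_grad=True,
                       device=device, dtype=dtype)

args = SimpleNamespace(seq_length=4, micro_batch_size=2, hidden_size=8,
                       fp32_residual_connection=True, params_dtype=torch.half)
print(alloc_recv_buffer(args, recv=True).dtype)  # torch.float32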