Commit 62632d39 authored by mohammad, committed by Deepak Narayanan

Add the option for fp32 residual connection (fp32 residual connection machinery still needs to be added)
parent 9b174da8
@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
+    if args.fp32_residual_connection:
+        assert args.fp16, \
+            'residual connection in fp32 only supports in fp16 mode.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \
@@ -435,6 +438,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--fp32-residual-connection', action='store_true',
+                       help='Move residual connections to fp32.')
     group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                        help='Scale Q * K^T by 1 / layer-number. If this flag '
                       'is set, then it will automatically set '
......
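For illustration only, a minimal standalone sketch (not part of the commit) of how the new flag is meant to interact with --fp16: argparse maps --fp32-residual-connection to args.fp32_residual_connection, and the assertion added above ties it to fp16 mode.

    # Stand-in parser for Megatron's parse_args(); the flag names match the
    # commit, everything else here is illustrative.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fp32-residual-connection', action='store_true')

    args = parser.parse_args(['--fp16', '--fp32-residual-connection'])
    if args.fp32_residual_connection:
        # Dropping --fp16 from the list above would trip this assertion,
        # mirroring the check added in parse_args().
        assert args.fp16, \
            'residual connection in fp32 only supports in fp16 mode.'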
@@ -13,9 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+_LAYER_NORM = None
+def import_layernorm(fp32_residual_connection):
+    global _LAYER_NORM
+    if not _LAYER_NORM:
+        if fp32_residual_connection:
+            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
+        else:
+            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+        _LAYER_NORM = LayerNorm
+    return _LAYER_NORM
+
 from .distributed import *
 from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
 from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
 from .utils import get_params_for_weight_decay_optimization
 from .language_model import get_language_model
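A minimal usage sketch of the new helper (assumes the fused layer-norm kernels are importable; the hidden size and eps below are illustrative):

    # Resolve the LayerNorm class once from the fp32-residual-connection
    # setting, then instantiate it like any other layer norm.
    from megatron.model import import_layernorm

    LayerNorm = import_layernorm(fp32_residual_connection=True)
    layernorm = LayerNorm(1024, eps=1e-5)  # hidden size 1024 is illustrative

Because the chosen class is cached in the module-level _LAYER_NORM, the first call fixes the implementation for the rest of the process, so callers are expected to pass args.fp32_residual_connection consistently.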
@@ -21,7 +21,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import import_layernorm
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        LayerNorm = import_layernorm(args.fp32_residual_connection)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
......
@@ -21,9 +21,9 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
-from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
 from megatron.checkpointing import get_checkpoint_version
+from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu
@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule):
             = args.apply_residual_connection_post_layernorm
         # Layernorm on the input data.
+        LayerNorm = import_layernorm(args.fp32_residual_connection)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()
+        self.fp32_residual_connection = args.fp32_residual_connection
+
         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers
@@ -520,6 +523,7 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
+            LayerNorm = import_layernorm(args.fp32_residual_connection)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)
@@ -564,7 +568,10 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_first_stage():
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
+            if self.fp32_residual_connection:
+                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
+            else:
+                hidden_states = hidden_states.transpose(0, 1).contiguous()
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
......
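To make the data flow at the first pipeline stage concrete, a small standalone sketch of the transpose-and-cast above (tensor sizes are illustrative, not taken from the commit):

    import torch

    # The embedding output arrives as [b, s, h] in the parameter dtype (fp16 here).
    b, s, h = 4, 1024, 2048  # illustrative micro-batch, sequence and hidden sizes
    hidden_states = torch.randn(b, s, h).half()

    fp32_residual_connection = True
    if fp32_residual_connection:
        # Transpose to [s, b, h] and upcast so the residual stream stays in fp32.
        hidden_states = hidden_states.transpose(0, 1).contiguous().float()
    else:
        hidden_states = hidden_states.transpose(0, 1).contiguous()

    assert hidden_states.shape == (s, b, h)
    assert hidden_states.dtype == torch.float32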
@@ -19,8 +19,8 @@ import math
 import torch
-from .transformer import LayerNorm
+from megatron import get_args
+from megatron.model import import_layernorm

 def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""
@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
+    args = get_args()
+    LayerNorm = import_layernorm(args.fp32_residual_connection)
+
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module_ in module.modules():
......
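For context, a hedged sketch of the grouping that get_params_for_weight_decay_optimization performs, with torch.nn.LayerNorm standing in for whatever class import_layernorm resolves and a made-up two-layer model:

    import torch
    from torch.nn import LayerNorm  # stand-in for the class import_layernorm returns

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), LayerNorm(16))

    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        if isinstance(module_, LayerNorm):
            # LayerNorm weights and biases get no weight decay.
            no_weight_decay_params['params'].extend(module_.parameters())
        else:
            for name, param in module_.named_parameters(recurse=False):
                # Biases elsewhere are also excluded from weight decay.
                group = no_weight_decay_params if name == 'bias' else weight_decay_params
                group['params'].append(param)

    optimizer = torch.optim.AdamW([weight_decay_params, no_weight_decay_params],
                                  lr=1e-4, weight_decay=0.01)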
@@ -41,7 +41,6 @@ from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_mod
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
-from .layers import LayerNorm
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
......
@@ -25,16 +25,6 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-    # Try to use FusedLayerNorm from Apex - this will trigger an error.
-    _ = LayerNorm(8, eps=1e-5)
-except Exception as e:
-    print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
-          'instead of apex.normalization.FusedLayerNorm!')
-    from torch.nn import LayerNorm
 from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
 from .mappings import copy_to_tensor_model_parallel_region
......
@@ -333,16 +333,19 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     tensor_recv_prev = None
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    dtype = args.params_dtype
+    if args.fp32_residual_connection:
+        dtype = torch.float
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
                                        device=torch.cuda.current_device(),
-                                       dtype=args.params_dtype)
+                                       dtype=dtype)
     if recv_backward:
         tensor_recv_next = torch.empty(tensor_shape,
                                        requires_grad=True,
                                        device=torch.cuda.current_device(),
-                                       dtype=args.params_dtype)
+                                       dtype=dtype)
     # Send tensors in both the forward and backward directions as appropriate.
     torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
......
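A minimal sketch of the buffer-allocation logic above, with an illustrative stand-in for Megatron's parsed args and the CUDA device argument dropped so it runs anywhere:

    import torch

    class _Args:  # illustrative stand-in for Megatron's args
        seq_length, micro_batch_size, hidden_size = 1024, 4, 2048
        params_dtype = torch.half
        fp32_residual_connection = True

    args = _Args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # With fp32 residual connections the [s, b, h] activations exchanged between
    # pipeline stages are fp32, so the receive buffers must be allocated to match.
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    tensor_recv_prev = torch.empty(tensor_shape, requires_grad=True, dtype=dtype)
    assert tensor_recv_prev.dtype == torch.float32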