Commit 124accba authored by dongcl's avatar dongcl

support split_bw when te_version > 2.3.0.dev0

parent bfe0b4a9
@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
                                     create_dummy=True)
     # backward_dw
-    if is_te_min_version("2.4.0.dev0"):
-        patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-                                       _get_extra_te_kwargs_wrapper,
-                                       apply_wrapper=True)
+    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+                                   _get_extra_te_kwargs_wrapper,
+                                   apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
......
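Note on the hunk above: the `is_te_min_version("2.4.0.dev0")` check at the registration site is dropped, so the `_get_extra_te_kwargs` patch is now registered unconditionally; the version decision moves into `_get_extra_te_kwargs_wrapper` itself (next file). A minimal sketch of what `apply_wrapper=True`-style patching is assumed to mean here; the `apply` helper below is hypothetical and only illustrates the composition, it is not the repo's actual patch manager:

def apply(original_func, patch, apply_wrapper=False):
    # Hypothetical: with apply_wrapper=True the registered object is treated
    # as a wrapper factory and is called with the original function.
    if apply_wrapper:
        return patch(original_func)   # e.g. _get_extra_te_kwargs_wrapper(orig)
    return patch                      # plain replacement otherwise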
 import os
+import copy
 import torch
 import dataclasses
 import transformer_engine as te
@@ -7,6 +8,7 @@ from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion
+from megatron.training import get_args
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.utils import get_te_version, is_te_min_version
@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
 )

-def _get_extra_te_kwargs_wrapper(fn):
-    @wraps(fn)
+def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
+    @wraps(_get_extra_te_kwargs_func)
     def wrapper(config: TransformerConfig):
-        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
         if hasattr(config, "split_bw"):
             extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
         return extra_transformer_engine_kwargs
-    return wrapper
+    if is_te_min_version("2.3.0.dev0"):
+        return wrapper
+    return _get_extra_te_kwargs_func


 class TELinear(MegatronCoreTELinear):
@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
         tp_comm_buffer_name: Optional[str] = None,
         is_expert: bool = False,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
         super().__init__(
             input_size,
@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
         if not self.split_bw:
             return
+        return super(MegatronCoreTELinear, self).backward_dw()


 class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
     """
@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
         skip_weight_param_allocation: bool = False,
         tp_comm_buffer_name: Optional[str] = None,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
         super().__init__(
             input_size,
@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
         if not self.split_bw:
             return
+        return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()


 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
             is_expert: bool = False,
             tp_comm_buffer_name: Optional[str] = None,
         ):
-            self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-            assert not self.split_bw, "split_bw is currently not supported"
+            args = get_args()
+            self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+            if not is_te_min_version("2.3.0.dev0"):
+                assert not self.split_bw, "split_bw is currently not supported"
+            if self.split_bw:
+                config = copy.copy(config)
+                config.split_bw = True
             super().__init__(
                 num_gemms,
@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
         def backward_dw(self):
             if not self.split_bw:
                 return
+            return super(MegatronCoreTEGroupedLinear, self).backward_dw()
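Taken together, the changes in this file read `split_bw` from the training args, stamp it onto a shallow-copied config so `_get_extra_te_kwargs_wrapper` passes `delay_wgrad_compute=True` to Transformer Engine, and make the `backward_dw` overrides forward to TE's deferred weight-gradient computation instead of returning early. A hedged sketch of the intended call order; the driver function below is illustrative, not this repo's scheduler, and assumes a module whose underlying TE layer was built with `delay_wgrad_compute=True`:

def run_split_backward(module, inp, loss_fn):
    # Forward as usual.
    out = module(inp)
    loss = loss_fn(out)
    # Backward computes activation gradients; the weight-gradient GEMM is
    # deferred because the TE layer was created with delay_wgrad_compute=True.
    loss.backward()
    # ... other work (e.g. all-to-all communication) can overlap here ...
    # Explicitly run the deferred weight-gradient computation.
    module.backward_dw()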
@@ -342,6 +342,8 @@ class MoeMlPNode(TransformerLayerNode):
         assert mlp_bias is None
         # pre_mlp_layernorm_output used
+        # cur_stream = torch.cuda.current_stream()
+        # self.common_state.pre_mlp_layernorm_output.record_stream(cur_stream)
         self.common_state.pre_mlp_layernorm_output = None
         return expert_output, shared_expert_output
......
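The two commented-out lines added above point at a CUDA stream-safety concern: `pre_mlp_layernorm_output` is dropped here, and if it was last used on a non-default stream, the caching allocator could reuse its memory before that stream finishes. A short illustration of the general `Tensor.record_stream` pattern; this is plain PyTorch semantics, not this repo's code:

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.randn(1024, device="cuda")
    with torch.cuda.stream(side_stream):
        y = x * 2                  # x is consumed on side_stream
    # Tell the caching allocator not to reuse x's memory until the work
    # currently queued on side_stream has finished.
    x.record_stream(side_stream)
    del x                          # safe to drop the Python reference now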
@@ -12,6 +12,7 @@ from megatron.core.distributed import DistributedDataParallel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
@@ -56,6 +57,11 @@ class ScheduleNode:
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
+        # Handle scalar output
+        if output_grad is None:
+            assert outputs.numel() == 1, "implicit grad requires scalar output."
+            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
@@ -441,12 +447,13 @@ def forward_backward_step(
                 output_tensor, num_tokens, loss_reduced = outputs
                 if not config.calculate_per_token_loss:
                     output_tensor /= num_tokens
+                    output_tensor *= parallel_state.get_context_parallel_world_size()
                     output_tensor /= num_microbatches
             else:
-                # preserve legacy loss averaging behavior
-                # (ie, over the number of microbatches)
+                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                 assert len(outputs) == 2
                 output_tensor, loss_reduced = outputs
+                output_tensor *= parallel_state.get_context_parallel_world_size()
                 output_tensor = output_tensor / num_microbatches
             forward_data_store.append(loss_reduced)
@@ -464,12 +471,11 @@ def forward_backward_step(
     # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
     # explicitly.
     if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
-        # Calculate the loss scale based on the grad_scale_func if available,
-        # else default to 1.
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
         loss_scale = (
             config.grad_scale_func(torch.ones(1, device=output_tensor.device))
             if config.grad_scale_func is not None
-            else torch.tensor(1.0)
+            else torch.ones(1, device=output_tensor.device)
         )
         # Set the loss scale
         if config.calculate_per_token_loss:
@@ -477,8 +483,23 @@ def forward_backward_step(
         else:
             MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

+    # Set the loss scale for Multi-Token Prediction (MTP) loss.
+    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.ones(1, device=output_tensor.device)
+        )
+        # Set the loss scale
+        if config.calculate_per_token_loss:
+            MTPLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
     if not unwrap_output_tensor:
         output_tensor, num_tokens = [output_tensor], num_tokens

     # backward post process
     input_tensor_grad = None
     if b_model is not None:
......
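The `default_backward_func` change above mirrors the `torch.autograd.backward` convention: a missing output gradient is only legal for a scalar output, in which case an implicit gradient of ones is used. A small self-contained check of that convention in plain PyTorch, independent of this repo:

import torch

x = torch.ones(3, requires_grad=True)
loss = x.sum()                      # scalar output
torch.autograd.backward(loss)       # None grad -> implicit ones_like(loss)
assert torch.equal(x.grad, torch.ones(3))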
@@ -40,9 +40,6 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

-    split_bw: bool = False
-    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
-

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
......
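With `split_bw` removed from `ExtraTransformerConfig`, the flag is now read from the training args and copied onto each linear layer's config just before construction, which is why the `hasattr(config, "split_bw")` check in `_get_extra_te_kwargs_wrapper` can stay unchanged. A hedged sketch of that propagation; the function name below is illustrative, the real logic lives in the patched `__init__` methods above:

import copy

def resolve_te_extra_kwargs(args, config, get_extra_te_kwargs):
    # Read the flag from args rather than from the shared TransformerConfig.
    split_bw = args.split_bw if hasattr(args, "split_bw") else False
    if split_bw:
        # Shallow-copy so the shared config object is not mutated.
        config = copy.copy(config)
        config.split_bw = True
    # The wrapper adds delay_wgrad_compute=config.split_bw when present.
    return get_extra_te_kwargs(config)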