Commit 124accba authored by dongcl

Support split_bw when te_version >= 2.3.0.dev0

parent bfe0b4a9
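
The core of the change is a version-gated wrapper around megatron's _get_extra_te_kwargs: when the installed Transformer Engine is at least 2.3.0.dev0, the repo-level split_bw switch is forwarded to TE as the delay_wgrad_compute kwarg; on older TE builds the original function is returned untouched. A minimal, self-contained sketch of that pattern (te_version here is a stand-in for the version probed by is_te_min_version, and config is any object that may carry a split_bw attribute):

from functools import wraps
from packaging.version import Version

te_version = Version("2.3.0")  # stand-in for the installed transformer_engine version

def get_extra_te_kwargs_wrapper(get_extra_te_kwargs_func):
    @wraps(get_extra_te_kwargs_func)
    def wrapper(config):
        kwargs = get_extra_te_kwargs_func(config)
        # Map the repo-level split_bw switch onto TE's delay_wgrad_compute kwarg.
        if hasattr(config, "split_bw"):
            kwargs["delay_wgrad_compute"] = config.split_bw
        return kwargs
    # Older TE builds do not understand delay_wgrad_compute, so below
    # 2.3.0.dev0 the original function is returned unchanged.
    if te_version >= Version("2.3.0.dev0"):
        return wrapper
    return get_extra_te_kwargs_func
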
@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
create_dummy=True)
# backward_dw
if is_te_min_version("2.4.0.dev0"):
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
import os
import copy
import torch
import dataclasses
import transformer_engine as te
@@ -7,6 +8,7 @@ from functools import wraps
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion
from megatron.training import get_args
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import get_te_version, is_te_min_version
@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
)
def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
@wraps(_get_extra_te_kwargs_func)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
if hasattr(config, "split_bw"):
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
return extra_transformer_engine_kwargs
return wrapper
if is_te_min_version("2.3.0.dev0"):
return wrapper
return _get_extra_te_kwargs_func
class TELinear(MegatronCoreTELinear):
@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
input_size,
@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
if not self.split_bw:
return
return super(MegatronCoreTELinear, self).backward_dw()
class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
"""
@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
input_size,
@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
if not self.split_bw:
return
return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()
class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
def __init__(
@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
num_gemms,
@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
def backward_dw(self):
if not self.split_bw:
return
return super(MegatronCoreTEGroupedLinear, self).backward_dw()
@@ -342,6 +342,8 @@ class MoeMlPNode(TransformerLayerNode):
assert mlp_bias is None
# pre_mlp_layernorm_output used
# cur_stream = torch.cuda.current_stream()
# self.common_state.pre_mlp_layernorm_output.record_stream(cur_stream)
self.common_state.pre_mlp_layernorm_output = None
return expert_output, shared_expert_output
@@ -12,6 +12,7 @@ from megatron.core.distributed import DistributedDataParallel
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
@@ -56,6 +57,11 @@ class ScheduleNode:
self.outputs = None
def default_backward_func(self, outputs, output_grad):
# Handle scalar output
if output_grad is None:
assert outputs.numel() == 1, "implicit grad requires scalar output."
output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
Variable._execution_engine.run_backward(
tensors=outputs,
grad_tensors=output_grad,
@@ -441,12 +447,13 @@ def forward_backward_step(
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior
# (ie, over the number of microbatches)
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor = output_tensor / num_microbatches
forward_data_store.append(loss_reduced)
@@ -464,12 +471,11 @@ def forward_backward_step(
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available,
# else default to 1.
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
@@ -477,8 +483,23 @@ def forward_backward_step(
else:
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
if not unwrap_output_tensor:
output_tensor, num_tokens = [output_tensor], num_tokens
# backward post process
input_tensor_grad = None
if b_model is not None:
@@ -40,9 +40,6 @@ class ExtraTransformerConfig:
combined_1f1b_recipe: str = 'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
split_bw: bool = False
"""If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
@dataclass
class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
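
For context on what the new backward_dw hooks buy: with delay_wgrad_compute=True, TE layers compute only input gradients during the autograd pass and leave the weight-gradient GEMM to an explicit backward_dw() call, which the combined-1F1B schedule can then overlap with communication. A rough usage sketch against TE's public Linear module, not code from this repo (assumes TE >= 2.3.0.dev0 and a CUDA device; shapes and the overlap point are illustrative):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024, delay_wgrad_compute=True)  # requires TE >= 2.3.0.dev0
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

out = layer(x)
out.sum().backward()   # dgrad pass only: the weight-gradient GEMM is deferred
layer.backward_dw()    # launch the deferred wgrad GEMM later, e.g. overlapped
                       # with another micro-batch's all-to-all communication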