"vscode:/vscode.git/clone" did not exist on "1bd24577200a6eb9ba8ecf22457e99278a34abc6"
Commit 56819e16 authored by dongcl

Merge branch 'a2a_overlap', support 1f1b overlap

parent 1e8185f4
......@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
create_dummy=True)
# backward_dw
if is_te_min_version("2.4.0.dev0"):
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
......
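The register_patch calls above hand replacement symbols to a patches manager whose internals are not part of this diff. As a rough sketch of the pattern (SimplePatchesManager and its methods are hypothetical stand-ins, not the real dcu_megatron API), apply_wrapper=True is presumably what lets _get_extra_te_kwargs_wrapper decorate the original function instead of replacing it outright:

# Hypothetical sketch of a register_patch/apply_wrapper mechanism; names are assumptions.
import importlib

class SimplePatchesManager:
    """Collects (target, replacement) pairs and applies them on demand."""

    def __init__(self):
        self._patches = []

    def register_patch(self, target_path, replacement, apply_wrapper=False):
        # target_path looks like 'package.module.attribute'
        self._patches.append((target_path, replacement, apply_wrapper))

    def apply(self):
        for target_path, replacement, apply_wrapper in self._patches:
            module_path, attr_name = target_path.rsplit('.', 1)
            module = importlib.import_module(module_path)
            original = getattr(module, attr_name)
            # apply_wrapper=True: "replacement" decorates the original instead of replacing it.
            setattr(module, attr_name, replacement(original) if apply_wrapper else replacement)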
import os
import copy
import torch
import dataclasses
import transformer_engine as te
......@@ -7,6 +8,7 @@ from functools import wraps
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion
from megatron.training import get_args
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import get_te_version, is_te_min_version
......@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
)
def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
@wraps(_get_extra_te_kwargs_func)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
if hasattr(config, "split_bw"):
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
return extra_transformer_engine_kwargs
return wrapper
if is_te_min_version("2.3.0.dev0"):
return wrapper
return _get_extra_te_kwargs_func
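The wrapper above forwards to the original _get_extra_te_kwargs and only injects delay_wgrad_compute when the config carries split_bw, and only on TE >= 2.3.0.dev0. A minimal standalone illustration of the same functools.wraps pattern, using toy names (ToyConfig, base_kwargs_fn) rather than the Megatron/TE functions:

# Toy illustration of the kwargs-augmenting wrapper pattern; not Megatron/TE code.
from dataclasses import dataclass
from functools import wraps

@dataclass
class ToyConfig:
    params_dtype: str = "bf16"
    split_bw: bool = False

def base_kwargs_fn(config):
    return {"params_dtype": config.params_dtype}

def with_delay_wgrad(fn):
    @wraps(fn)
    def wrapper(config):
        kwargs = fn(config)
        # only add the extra key when the feature flag exists on the config
        if getattr(config, "split_bw", False):
            kwargs["delay_wgrad_compute"] = True
        return kwargs
    return wrapper

patched = with_delay_wgrad(base_kwargs_fn)
print(patched(ToyConfig(split_bw=True)))  # {'params_dtype': 'bf16', 'delay_wgrad_compute': True}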
class TELinear(MegatronCoreTELinear):
......@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
input_size,
......@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
if not self.split_bw:
return
return super(MegatronCoreTELinear, self).backward_dw()
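backward_dw defers the weight-gradient GEMM so the schedule below can trigger it later (e.g. b_layer.mlp.dw()) and overlap it with other work. The TransformerEngine plumbing behind delay_wgrad_compute is not shown here; the toy autograd sketch below only illustrates the dgrad/wgrad split idea and is not the TE implementation:

# Toy illustration of split dgrad/wgrad (the idea behind split_bw / backward_dw).
import torch

class _DelayedWgradLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, pending):
        ctx.save_for_backward(x, weight)
        ctx.pending = pending
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # dgrad now; stash a closure that produces wgrad later
        ctx.pending.append(lambda: grad_out.t() @ x)
        return grad_out @ weight, None, None

class ToyLinearSplitBW(torch.nn.Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_f, in_f))
        self._pending_wgrads = []

    def forward(self, x):
        return _DelayedWgradLinear.apply(x, self.weight, self._pending_wgrads)

    def backward_dw(self):
        # run the deferred weight-gradient GEMMs, e.g. overlapped with communication
        with torch.no_grad():
            for fn in self._pending_wgrads:
                g = fn()
                self.weight.grad = g if self.weight.grad is None else self.weight.grad + g
        self._pending_wgrads.clear()

layer = ToyLinearSplitBW(4, 3)
layer(torch.randn(2, 4, requires_grad=True)).sum().backward()  # dgrad only
layer.backward_dw()                                            # wgrad on demand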
class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
"""
......@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
input_size,
......@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
if not self.split_bw:
return
return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()
class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
def __init__(
......@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
if not is_te_min_version("2.3.0.dev0"):
assert not self.split_bw, "split_bw is currently not supported"
if self.split_bw:
config = copy.copy(config)
config.split_bw = True
super().__init__(
num_gemms,
......@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
def backward_dw(self):
if not self.split_bw:
return
return super(MegatronCoreTEGroupedLinear, self).backward_dw()
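All three patched classes read split_bw from get_args() and shallow-copy the shared config before setting the flag, so the per-module change does not leak into other modules built from the same config object. A tiny sketch of that copy-before-mutate pattern with a toy config:

# Copy-before-mutate sketch; ToyTransformerConfig is a stand-in, not Megatron's config.
import copy
from dataclasses import dataclass

@dataclass
class ToyTransformerConfig:
    hidden_size: int = 8
    split_bw: bool = False

shared = ToyTransformerConfig()
local = copy.copy(shared)         # per-module shallow copy
local.split_bw = True
assert shared.split_bw is False   # the shared config is untouched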
......@@ -540,6 +540,8 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
def state(self):
return self._model_chunk_state
# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
F_DISPATCH_B_MLP_SYNC_EVENT = None
def schedule_layer_1f1b(
f_layer,
......@@ -579,13 +581,17 @@ def schedule_layer_1f1b(
with f_context:
f_input = f_layer.attn.forward(f_input)
f_dispatch_b_mlp_sync_event = None
if f_layer is not None and b_layer is not None:
f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
if f_layer is not None:
with f_context:
f_input = f_layer.dispatch.forward(f_input)
f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
if b_layer is not None:
with b_context:
b_grad = b_layer.mlp.backward(b_grad)
b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
b_grad = b_layer.dispatch.backward(b_grad)
b_layer.mlp.dw()
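The new stream_record_event / stream_wait_event arguments let the forward dispatch signal the backward MLP across streams without a blocking synchronize. A minimal sketch of that record/wait pattern, with toy matmuls standing in for the real dispatch and MLP nodes:

# Event-based cross-stream ordering: record after the producer, wait before the consumer.
import torch

if torch.cuda.is_available():
    comp_stream = torch.cuda.current_stream()
    com_stream = torch.cuda.Stream()
    sync_event = torch.cuda.Event()

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    with torch.cuda.stream(com_stream):
        dispatched = a @ a                # stands in for f_layer.dispatch.forward(...)
        sync_event.record(com_stream)     # like stream_record_event in the diff

    with torch.cuda.stream(comp_stream):
        sync_event.wait(comp_stream)      # like stream_wait_event before mlp.backward(...)
        grad = b @ b                      # stands in for the backward MLP work
    torch.cuda.synchronize()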
......@@ -688,32 +694,28 @@ def schedule_chunk_1f1b(
)
torch.cuda.nvtx.range_pop()
# tail forward
f_input = layer_pre_forward()
del layer_pre_forward
# tail backward
grad = layer_pre_backward()
del layer_pre_backward
with b_context:
for i in range(overlaped_layers, b_num_layers):
b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
_, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
torch.cuda.nvtx.range_pop()
if b_schedule_plan is not None:
b_schedule_plan.pre_process.backward(grad)
# tail forward
f_input = layer_pre_forward()
del layer_pre_forward
with f_context:
for i in range(overlaped_layers, f_num_layers):
f_layer = f_schedule_plan.get_layer(i)
torch.cuda.nvtx.range_push(f"layer_{i}f")
f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
torch.cuda.nvtx.range_pop()
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
f_input = f_schedule_plan.post_process.forward(f_input)
# output pp send receive, overlapped with attn backward
if f_schedule_plan is not None and post_forward is not None:
with f_context:
......@@ -730,6 +732,13 @@ def schedule_chunk_1f1b(
layer_pre_backward_dw()
del layer_pre_backward_dw
with f_context:
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
f_input = f_schedule_plan.post_process.forward(f_input)
with b_context:
if b_schedule_plan is not None:
b_schedule_plan.pre_process.backward(grad)
if f_schedule_plan:
f_schedule_plan.wait_current_stream()
if b_schedule_plan:
......@@ -764,7 +773,7 @@ def build_model_chunk_schedule_plan(
state.attention_mask = attention_mask
state.decoder_input = decoder_input
state.labels = labels
state.inference_context =inference_context
state.inference_context = inference_context
state.packed_seq_params = packed_seq_params
state.extra_block_kwargs = extra_block_kwargs
state.runtime_gather_output = runtime_gather_output
......
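Taken together, the reordering in schedule_chunk_1f1b above keeps the per-layer 1F1B interleave and moves post_process.forward / pre_process.backward after the overlapped region. A rough, pseudocode-level skeleton of the chunk schedule (contexts, NVTX ranges, pre/post-process and dw calls omitted; schedule_layer_1f1b is passed in so the sketch stays self-contained):

# Rough skeleton only; the real schedule carries much more state than this.
def schedule_chunk_1f1b_sketch(f_layers, b_layers, f_input, b_grad, schedule_layer_1f1b):
    overlapped = min(len(f_layers), len(b_layers))
    for i in range(overlapped):
        # forward of layer i overlapped with backward of the mirrored layer
        f_input, b_grad, _ = schedule_layer_1f1b(
            f_layers[i], b_layers[len(b_layers) - 1 - i], f_input=f_input, b_grad=b_grad)
    for i in range(overlapped, len(b_layers)):        # tail backward first
        _, b_grad, _ = schedule_layer_1f1b(None, b_layers[len(b_layers) - 1 - i], b_grad=b_grad)
    for i in range(overlapped, len(f_layers)):        # then tail forward
        f_input, _, _ = schedule_layer_1f1b(f_layers[i], None, f_input=f_input)
    return f_input, b_grad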
import contextlib
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple, Union
from typing import List, Union
import torch
from torch import Tensor
......@@ -12,13 +12,10 @@ from megatron.core.distributed import DistributedDataParallel
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
# Types
Shape = Union[List[int], torch.Size]
def make_viewless(e):
"""make_viewless util func"""
e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
......@@ -56,6 +53,11 @@ class ScheduleNode:
self.outputs = None
def default_backward_func(self, outputs, output_grad):
# Handle scalar output
if output_grad is None:
assert outputs.numel() == 1, "implicit grad requires scalar output."
output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
Variable._execution_engine.run_backward(
tensors=outputs,
grad_tensors=output_grad,
......@@ -67,17 +69,20 @@ class ScheduleNode:
)
return output_grad
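The added scalar branch mirrors PyTorch's implicit-gradient convention: a scalar output can be backpropagated without an explicit grad tensor, while non-scalar outputs need one. For reference:

# PyTorch's implicit-gradient convention, which the scalar branch above reproduces.
import torch

x = torch.randn(3, requires_grad=True)
loss = (x * x).sum()                  # scalar output
loss.backward()                       # implicit grad_output = torch.ones_like(loss)

y = torch.randn(3, requires_grad=True)
out = y * y                           # non-scalar output
out.backward(torch.ones_like(out))    # explicit grad_output is required here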
def forward(self, inputs=()):
def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
"""schedule node forward"""
if not isinstance(inputs, tuple):
inputs = (inputs,)
return self._forward(*inputs)
return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
def _forward(self, *inputs):
def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
with stream_acquire_context(self.stream, self.event):
torch.cuda.nvtx.range_push(f"{self.name} forward")
with torch.cuda.stream(self.stream):
if stream_wait_event is not None:
stream_wait_event.wait(self.stream)
self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
for i, input in enumerate(self.inputs):
if input is not None:
......@@ -92,6 +97,10 @@ class ScheduleNode:
data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
self.output = data
if stream_record_event is not None:
stream_record_event.record(self.stream)
torch.cuda.nvtx.range_pop()
if self.free_inputs:
......@@ -105,16 +114,19 @@ class ScheduleNode:
"""get the forward output"""
return self.output
def backward(self, output_grad):
def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
"""schedule node backward"""
if not isinstance(output_grad, tuple):
output_grad = (output_grad,)
return self._backward(*output_grad)
return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
def _backward(self, *output_grad):
def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
with stream_acquire_context(self.stream, self.event):
torch.cuda.nvtx.range_push(f"{self.name} backward")
with torch.cuda.stream(self.stream):
if stream_wait_event is not None:
stream_wait_event.wait(self.stream)
outputs = self.output
if not isinstance(outputs, tuple):
outputs = (outputs,)
......@@ -125,6 +137,10 @@ class ScheduleNode:
output_grad = self.backward_func(outputs, output_grad)
else:
output_grad = self.default_backward_func(outputs, output_grad)
if stream_record_event is not None:
stream_record_event.record(self.stream)
torch.cuda.nvtx.range_pop()
# output_grad maybe from another stream
......@@ -192,17 +208,6 @@ def schedule_chunk_1f1b(
)
def schedule_chunk_forward(schedule_plan):
"""model level fine-grained forward schedule"""
f_input = schedule_chunk_1f1b(schedule_plan, None, None)
return f_input
def schedule_chunk_backward(schedule_plan, grad):
"""model level fine-grained backward schedule"""
tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
_COMP_STREAM = None
_COM_STREAM = None
......@@ -215,7 +220,7 @@ def set_streams(comp_stream=None, com_stream=None):
return
if comp_stream is None:
comp_stream = torch.cuda.Stream(device="cuda")
comp_stream = torch.cuda.current_stream()
if com_stream is None:
com_stream = torch.cuda.Stream(device="cuda")
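set_streams now defaults the compute stream to the current stream and only creates a dedicated side stream for communication, which presumably avoids extra synchronization against work already issued on the default stream. A minimal sketch of that default (make_streams is a toy stand-in, not the module's API):

# Sketch of the stream defaults after this change.
import torch

def make_streams(comp_stream=None, com_stream=None):
    if not torch.cuda.is_available():
        return None, None
    if comp_stream is None:
        comp_stream = torch.cuda.current_stream()      # compute stays on the default stream
    if com_stream is None:
        com_stream = torch.cuda.Stream(device="cuda")  # side stream for a2a communication
    return comp_stream, com_stream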
......@@ -342,7 +347,7 @@ def forward_backward_step(
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
from .schedules import set_current_microbatch
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
......@@ -441,12 +446,13 @@ def forward_backward_step(
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior
# (ie, over the number of microbatches)
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor = output_tensor / num_microbatches
forward_data_store.append(loss_reduced)
......@@ -464,12 +470,11 @@ def forward_backward_step(
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available,
# else default to 1.
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
......@@ -477,8 +482,23 @@ def forward_backward_step(
else:
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
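The MTP branch mirrors the MoE aux-loss branch: the scale comes from grad_scale_func when gradient scaling is active, otherwise a ones tensor kept on the output's device, and it is divided by num_microbatches unless per-token loss accounting is enabled. A toy numeric sketch of that bookkeeping (values are made up):

# Toy numeric sketch of the loss-scale arithmetic above.
import torch

num_microbatches = 4
calculate_per_token_loss = False
grad_scale_func = None                    # e.g. the grad scaler hook when fp16 is enabled

device = "cuda" if torch.cuda.is_available() else "cpu"
out = torch.zeros(1, device=device)

loss_scale = (
    grad_scale_func(torch.ones(1, device=out.device))
    if grad_scale_func is not None
    else torch.ones(1, device=out.device)   # keep the default on-device, no host round trip
)
if not calculate_per_token_loss:
    loss_scale = loss_scale / num_microbatches   # average over microbatches
print(loss_scale)                                # tensor([0.25])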
if not unwrap_output_tensor:
output_tensor, num_tokens = [output_tensor], num_tokens
# backward post process
input_tensor_grad = None
if b_model is not None:
......
import contextlib
from typing import Callable, Iterator, List, Optional, Union
from typing import Iterator, List, Union
import torch
......@@ -7,10 +7,8 @@ from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import (
get_attr_wrapped_model,
get_model_config,
get_model_type,
get_model_xattn,
......@@ -448,6 +446,8 @@ def forward_backward_pipelining_with_interleaving(
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
nonlocal output_tensor_grads
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
......
......@@ -40,8 +40,7 @@ class ExtraTransformerConfig:
combined_1f1b_recipe: str = 'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
split_bw: bool = False
"""If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
# split_bw: bool = False
@dataclass
......
from functools import partial
from typing import Any, Optional
import torch
from torch import Tensor
from megatron.core import tensor_parallel
......@@ -12,8 +10,7 @@ from megatron.core.utils import (
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
class TransformerLayer(MegatronCoreTransformerLayer):
......@@ -34,7 +31,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
inference_params: Optional[Any] = None,
):
if not isinstance(self.mlp, MoELayer):
if (
not isinstance(self.mlp, MoELayer)
or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
):
return super().forward(
hidden_states=hidden_states,
context=context,
......@@ -55,7 +55,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
probs,
_,
) = self._submodule_attention_router_compound_forward(
hidden_states,
attention_mask,
......
......@@ -97,7 +97,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
print("> initializing torch distributed ...", flush=True)
# Manually set the device ids.
if device_count > 0:
torch.cuda.set_device(args.local_rank % device_count)
torch.cuda.set_device(args.local_rank)
device_id = torch.device(f'cuda:{args.local_rank}')
else:
device_id = None
......
......@@ -195,7 +195,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=trace_handler,
record_shapes=True,
record_shapes=True,
with_stack=True,
)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
......
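Adding with_stack=True makes the captured trace carry Python stack information for each recorded op, at some extra overhead. A self-contained torch.profiler example with the same schedule shape as above (step counts and output path here are arbitrary):

# Self-contained torch.profiler example with with_stack=True.
import torch
from torch.profiler import profile, schedule, tensorboard_trace_handler

prof = profile(
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profile_out"),
    record_shapes=True,
    with_stack=True,      # attach Python stacks to recorded ops
)
prof.start()
for _ in range(5):
    x = torch.randn(256, 256)
    (x @ x).sum()
    prof.step()
prof.stop()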
......@@ -16,9 +16,6 @@ from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
......@@ -38,7 +35,6 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from dcu_megatron import megatron_adaptor
......@@ -102,8 +98,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
if args.num_experts:
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
elif args.heterogeneous_layers_config_path is not None:
transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
else:
# Define the decoder layer spec
if use_te:
......