Commit 5890bb4c authored by dongcl

fix import error

parent 4bb958ec
@@ -5,13 +5,17 @@ def a2a_overlap_adaptation(patches_manager):
     """
     patches_manager: MegatronPatchesManager
     """
+    from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
     from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
     from ..core.transformer.transformer_block import TransformerBlock
     from ..core.transformer.transformer_layer import TransformerLayer
     from ..core.models.gpt.gpt_model import GPTModel
     from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
-    from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
+    from ..core.extensions.transformer_engine import (
+        _get_extra_te_kwargs_wrapper,
+        TELinear,
+        TELayerNormColumnParallelLinear,
+    )
     from ..core.transformer.multi_latent_attention import MLASelfAttention
     from ..core.transformer.mlp import MLP
     from ..core.transformer.moe.experts import TEGroupedMLP
@@ -38,23 +42,34 @@ def a2a_overlap_adaptation(patches_manager):
                                    GPTModel)
     # backward_dw
-    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-                                   _get_extra_te_kwargs_wrapper,
-                                   apply_wrapper=True)
+    # patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+    #                                _get_extra_te_kwargs_wrapper,
+    #                                apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                                    TELayerNormColumnParallelLinear)
+    TEColumnParallelLinear.__bases__ = (TELinear,)
+    TERowParallelLinear.__bases__ = (TELinear,)
     if is_te_min_version("1.9.0.dev0"):
+        from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
         from ..core.extensions.transformer_engine import TEGroupedLinear
         patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
                                        TEGroupedLinear)
-    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
-                                   MLASelfAttention)
-    patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
-                                   MLP)
-    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
-                                   TEGroupedMLP)
-    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
-                                   MoELayer)
+        TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
+        TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
+    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
+                                   MLASelfAttention.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
+                                   MLP.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
+                                   TEGroupedMLP.backward_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
+                                   MoELayer.backward_dw,
+                                   create_dummy=True)
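Note on the hunk above: rebinding `__bases__` is how the patch makes already-defined Transformer Engine subclasses (`TEColumnParallelLinear`, `TERowParallelLinear`, and the grouped variants) inherit from the patched `TELinear`/`TEGroupedLinear` without redefining or re-importing them. A minimal sketch of that mechanism with stand-in classes (not the real Megatron/TE classes):

```python
# Illustrative stand-ins only: subclasses created before the patch still point at
# the original base, so rebinding their __bases__ to the patched base makes the
# new inherited behaviour (e.g. a backward_dw hook) visible to them as well.

class OriginalLinear:                         # stands in for the upstream TELinear
    def backward_dw(self):
        raise NotImplementedError

class ColumnParallelLinear(OriginalLinear):   # defined against the original base
    pass

class PatchedLinear(OriginalLinear):          # stands in for the patched TELinear
    def backward_dw(self):
        return "deferred wgrad computed"

# Mirrors TEColumnParallelLinear.__bases__ = (TELinear,) in the diff.
ColumnParallelLinear.__bases__ = (PatchedLinear,)

assert ColumnParallelLinear().backward_dw() == "deferred wgrad computed"
```

The method-level `register_patch('....backward_dw', ..., create_dummy=True)` calls serve the complementary purpose: they attach the new `backward_dw` hooks directly to the upstream classes, with `create_dummy=True` presumably allowing a placeholder when the upstream class does not define that attribute yet.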
@@ -104,7 +104,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)

     def patch_core_transformers(self):
@@ -122,12 +122,12 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     MLATransformerConfigPatch)

         # Moe
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
-                                    apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
+        #                             apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
......
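Note: the two `torch.compile(options=...)` registrations are commented out by this commit, while `permute` stays wrapped with `torch.compile(mode='max-autotune-no-cudagraphs')` via `apply_wrapper=True`. A rough sketch of what that wrapper application amounts to, using a stand-in `permute` rather than the real `moe_utils.permute`:

```python
import torch

def permute(tokens: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # stand-in for megatron.core.transformer.moe.moe_utils.permute
    return tokens[indices]

# torch.compile(...) called with only keyword arguments returns a decorator,
# which is what apply_wrapper=True applies to the registered target function.
compiled_permute = torch.compile(mode='max-autotune-no-cudagraphs')(permute)

tokens = torch.randn(8, 4)
indices = torch.randperm(8)
assert torch.equal(compiled_permute(tokens, indices), permute(tokens, indices))
```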
@@ -3,6 +3,7 @@ import torch
 import dataclasses
 import transformer_engine as te
+from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion
@@ -18,7 +19,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCore
 from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
 from megatron.core.parallel_state import (
-    get_context_parallel_global_ranks,
     get_context_parallel_group,
     get_hierarchical_context_parallel_groups,
     get_tensor_model_parallel_group,
@@ -29,7 +29,7 @@ def _get_extra_te_kwargs_wrapper(fn):
     @wraps(fn)
     def wrapper(config: TransformerConfig):
         extra_transformer_engine_kwargs = fn(config)
-        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
+        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
         return extra_transformer_engine_kwargs
     return wrapper
@@ -66,7 +66,7 @@ class TELinear(MegatronCoreTELinear):
         is_expert: bool = False,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
@@ -109,7 +109,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
         tp_comm_buffer_name: Optional[str] = None,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
@@ -314,7 +314,7 @@ if is_te_min_version("1.9.0.dev0"):
         tp_comm_buffer_name: Optional[str] = None,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
-        self.split_bw = config.get("split_bw", False)
+        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
         assert not self.split_bw, "split_bw is currently not supported"
         super().__init__(
......
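Note on the repeated `split_bw` change: `TransformerConfig` is a dataclass, not a dict, so the old `config.get("split_bw", False)` would raise `AttributeError`; the commit switches to attribute access with a `hasattr` guard. A tiny illustration with a hypothetical config class (not the real `TransformerConfig`):

```python
from dataclasses import dataclass

@dataclass
class SimpleConfig:
    hidden_size: int = 1024
    # split_bw may or may not exist, depending on whether the config patch is applied.

cfg = SimpleConfig()

# cfg.get("split_bw", False)   # AttributeError: 'SimpleConfig' object has no attribute 'get'
split_bw = cfg.split_bw if hasattr(cfg, "split_bw") else False   # pattern used in the diff
split_bw_alt = getattr(cfg, "split_bw", False)                   # equivalent, more compact
assert split_bw is False and split_bw_alt is False
```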
 import contextlib
 import weakref
-from typing import Any, Callable, Optional, Tuple, Union
+from collections import OrderedDict
+from typing import Optional

 import torch
 from torch import Tensor

-from megatron.core.pipeline_parallel.combined_1f1b import (
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer import transformer_layer
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.utils import deprecate_inference_params
+from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
+from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
     ScheduleNode,
     get_com_stream,
-    get_comp_stream,
     make_viewless,
 )
-from megatron.core.transformer import transformer_layer
-from megatron.core.transformer.module import float16_to_fp32
-from megatron.core.transformer.moe.moe_layer import MoELayer
-from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
 def weak_method(method):
@@ -43,6 +48,7 @@ class PreProcessNode(ScheduleNode):
         input_ids = self.model_chunk_state.input_ids
         position_ids = self.model_chunk_state.position_ids
         inference_context = self.model_chunk_state.inference_context
+        inference_params = self.model_chunk_state.inference_params
         packed_seq_params = self.model_chunk_state.packed_seq_params
         inference_context = deprecate_inference_params(inference_context, inference_params)
@@ -121,22 +127,6 @@ class PostProcessNode(ScheduleNode):
         self.gpt_model = gpt_model
         self.model_chunk_state = model_chunk_state

-        state.input_ids = input_ids
-        state.position_ids = position_ids
-        state.attention_mask = attention_mask
-        state.decoder_input = decoder_input
-        state.labels = labels
-        state.inference_context =inference_context
-        state.packed_seq_params = packed_seq_params
-        state.extra_block_kwargs = extra_block_kwargs
-        state.runtime_gather_output = runtime_gather_output
-        state.inference_params = inference_params
-        state.loss_mask = loss_mask
-        state.context = None
-        state.context_mask = None
-        state.attention_bias = None

     def forward_impl(self, hidden_states):
         gpt_model = self.gpt_model
@@ -145,11 +135,13 @@ class PostProcessNode(ScheduleNode):
         labels = self.model_chunk_state.labels
         loss_mask = self.model_chunk_state.loss_mask
         attention_mask = self.model_chunk_state.attention_mask
+        decoder_input = self.model_chunk_state.decoder_input
         inference_params= self.model_chunk_state.inference_params
         rotary_pos_emb = self.model_chunk_state.rotary_pos_emb
         rotary_pos_cos = self.model_chunk_state.rotary_pos_cos
         rotary_pos_sin = self.model_chunk_state.rotary_pos_sin
         packed_seq_params = self.model_chunk_state.packed_seq_params
+        extra_block_kwargs = self.model_chunk_state.extra_block_kwargs
         sequence_len_offset = self.model_chunk_state.sequence_len_offset
         runtime_gather_output = self.model_chunk_state.runtime_gather_output
         inference_context = self.model_chunk_state.inference_context
@@ -267,6 +259,9 @@ class TransformerLayerNode(ScheduleNode):
     def backward_impl(self, outputs, output_grad):
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
+        # if len(detached_grad):
+        #     print(f"output_grad: {grads}")
         self.default_backward_func(outputs + self.before_detached, grads)
         self.before_detached = None
         self.detached = None
@@ -296,7 +291,6 @@ class MoeAttnNode(TransformerLayerNode):
             tokens_per_expert,
             permutated_local_input_tokens,
             permuted_probs,
-            probs,
         ) = self.layer._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask=attention_mask,
@@ -312,7 +306,6 @@ class MoeAttnNode(TransformerLayerNode):
         self.common_state.tokens_per_expert = tokens_per_expert

         # detached here
-        self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
         self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
@@ -334,7 +327,7 @@ class MoeDispatchNode(TransformerLayerNode):
         )
         # release tensor not used by backward
         # inputs.untyped_storage().resize_(0)
-        self.common_state.tokens_per_expert = = tokens_per_expert
+        self.common_state.tokens_per_expert = tokens_per_expert
         return global_input_tokens, global_probs
@@ -345,7 +338,7 @@ class MoeMlPNode(TransformerLayerNode):
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-                self.common_state.tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output
+                self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
             )
             assert mlp_bias is None
@@ -372,9 +365,7 @@ class MoeCombineNode(TransformerLayerNode):
         )
         cur_stream = torch.cuda.current_stream()
         self.common_state.residual.record_stream(cur_stream)
-        self.common_state.probs.record_stream(cur_stream)
         self.common_state.residual = None
-        self.common_state.probs = None
         return output
@@ -554,21 +545,18 @@ def schedule_layer_1f1b(
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()

     if pre_forward is not None:
         assert f_input is None
         # combine from last iter
         f_input = pre_forward()
         del pre_forward
     if pre_backward is not None:
         # attn backward from last iter
         assert b_grad is None
         b_grad = pre_backward()
         del pre_backward
     if b_layer is not None:
         with b_context:
             b_grad = b_layer.combine.backward(b_grad)
@@ -577,7 +565,6 @@ def schedule_layer_1f1b(
             pre_backward_dw()
             del pre_backward_dw
     if f_layer is not None:
         with f_context:
             f_input = f_layer.attn.forward(f_input)
@@ -592,7 +579,6 @@ def schedule_layer_1f1b(
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
     if f_layer is not None:
         with f_context:
             f_input = f_layer.mlp.forward(f_input)
@@ -614,7 +600,6 @@ def schedule_layer_1f1b(
         with b_context:
             b_layer.attn.dw()
     if f_layer and b_layer:
         return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
     else:
......
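Note: `TransformerLayerNode.backward_impl` relies on a detach-and-collect-grad pattern: tensors detached at schedule-node boundaries accumulate `.grad` when the downstream node runs backward, and those grads are then fed to `torch.autograd.backward` alongside the node's own output grads (this hunk simply drops `probs` from that detached set). A minimal, self-contained sketch of the pattern, with illustrative names rather than the real `ScheduleNode` API:

```python
import torch

x = torch.randn(4, requires_grad=True)
hidden = x * 2                                    # produced inside the upstream node
detached = hidden.detach().requires_grad_(True)   # boundary between schedule nodes

out = (detached ** 2).sum()                       # downstream node's computation
out.backward()                                    # fills detached.grad only

# Upstream node backward: push the collected grad through its own autograd graph,
# analogous to default_backward_func(outputs + before_detached, grads).
torch.autograd.backward(hidden, grad_tensors=detached.grad)
assert torch.allclose(x.grad, 2 * detached.grad)
```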
@@ -7,10 +7,10 @@ from functools import wraps
 import torch
 from torch import Tensor

-from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import WrappedTensor, deprecate_inference_params
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
......
@@ -427,7 +427,7 @@ def forward_backward_step(
     if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
-            if parallel_state.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
                         loss_func,
......
@@ -2,17 +2,13 @@ import contextlib
 from typing import Callable, Iterator, List, Optional, Union

 import torch
-from torch.autograd.variable import Variable

 from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
-from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
-from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import (
-    drain_embedding_wgrad_compute,
     get_attr_wrapped_model,
     get_model_config,
     get_model_type,
@@ -32,6 +28,18 @@ from megatron.core.pipeline_parallel.schedules import (
 from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


+def set_current_microbatch(model, microbatch_id):
+    """Set the current microbatch."""
+    decoder_exists = True
+    decoder = None
+    try:
+        decoder = get_attr_wrapped_model(model, "decoder")
+    except RuntimeError:
+        decoder_exists = False
+    if decoder_exists and decoder is not None:
+        decoder.current_microbatch = microbatch_id
+
+
 def get_pp_rank_microbatches(
     num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
 ):
@@ -541,7 +549,7 @@ def forward_backward_pipelining_with_interleaving(
             )

             # forward step
-            if parallel_state.is_pipeline_first_stage():
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
                 if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                     input_tensors[model_chunk_id].append(None)
@@ -573,7 +581,7 @@ def forward_backward_pipelining_with_interleaving(
                     enable_grad_sync()
                     synchronized_model_chunks.add(model_chunk_id)

-            if parallel_state.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
                 if len(output_tensor_grads[model_chunk_id]) == 0:
                     output_tensor_grads[model_chunk_id].append(None)
             b_input_tensor = input_tensors[model_chunk_id].pop(0)
@@ -679,7 +687,6 @@ def forward_backward_pipelining_with_interleaving(
                 post_backward=post_backward,
             )
         else:
-            output_tensor = None
             input_tensor_grad = None
             if f_virtual_microbatch_id is not None:
                 # forward pass
@@ -704,7 +711,7 @@ def forward_backward_pipelining_with_interleaving(
                 input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
                 if post_backward is not None:
                     input_tensor_grad = post_backward(input_tensor_grad)
-        return output_tensor, input_tensor_grad
+        return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad

     # Run warmup forward passes.
     parallel_state.set_virtual_pipeline_model_parallel_rank(0)
@@ -890,6 +897,7 @@ def forward_backward_pipelining_with_interleaving(
         output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])

     # Run 1F1B in steady state.
+    output_tensor = None
     for k in range(num_microbatches_remaining):
         # Forward pass.
         forward_k = k + num_warmup_microbatches
......
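Note: `set_current_microbatch` is reimplemented locally here, likely because importing it from the upstream `megatron.core.pipeline_parallel.schedules` fails in the targeted Megatron version (which matches the commit title). A hedged sketch of how such a helper is typically driven per microbatch, with stand-in classes instead of the real wrapped model:

```python
# DummyModel/DummyDecoder and the simplified get_attr_wrapped_model below are
# illustrative stand-ins, not the real Megatron objects.
class DummyDecoder:
    current_microbatch = None

class DummyModel:
    def __init__(self):
        self.decoder = DummyDecoder()

def get_attr_wrapped_model(model, attr):
    # simplified stand-in for megatron.core.utils.get_attr_wrapped_model
    if not hasattr(model, attr):
        raise RuntimeError(f"{attr} not found")
    return getattr(model, attr)

def set_current_microbatch(model, microbatch_id):
    try:
        decoder = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        return
    decoder.current_microbatch = microbatch_id

model = DummyModel()
for microbatch_id in range(4):        # one tag per scheduled microbatch
    set_current_microbatch(model, microbatch_id)
assert model.decoder.current_microbatch == 3
```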
-from megatron.core.transformer.mlp import MLP as MegatronCoreMLP
-class MLP(MegatronCoreMLP):
+class MLP():
     def backward_dw(self):
         self.linear_fc2.backward_dw()
         self.linear_fc1.backward_dw()

-from megatron.core.transformer.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP
-class TEGroupedMLP(MegatronCoreTEGroupedMLP):
+class TEGroupedMLP():
     def backward_dw(self):
         self.linear_fc2.backward_dw()
         self.linear_fc1.backward_dw()

-from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer
-class MoELayer(MegatronCoreMoELayer):
+class MoELayer():
     def backward_dw(self):
         self.experts.backward_dw()
         self.shared_experts.backward_dw()
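Note: `MLP`, `TEGroupedMLP`, and `MoELayer` drop their upstream base classes here and become plain containers for `backward_dw`; combined with the `register_patch('....backward_dw', ..., create_dummy=True)` calls in the first file, the methods end up grafted onto the upstream classes rather than shadowing them. A sketch of the grafting idea with hypothetical stand-ins, under the assumption that `register_patch` ultimately assigns the given function at the dotted attribute path:

```python
class FakeLinear:
    def backward_dw(self):
        print(f"wgrad for {self.__class__.__name__}")

class UpstreamMLP:                 # stands in for megatron.core.transformer.mlp.MLP
    def __init__(self):
        self.linear_fc1 = FakeLinear()
        self.linear_fc2 = FakeLinear()

class MLP:                         # patch container, mirrors the diff
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()

# Roughly what register_patch('...MLP.backward_dw', MLP.backward_dw, create_dummy=True)
# amounts to: attach the plain function to the already-imported upstream class.
UpstreamMLP.backward_dw = MLP.backward_dw

UpstreamMLP().backward_dw()        # runs both deferred wgrad computations
```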
+from contextlib import contextmanager
+from typing import Optional, Tuple
+
+import torch
+
+from megatron.core.tensor_parallel import (
+    all_to_all,
+    gather_from_sequence_parallel_region,
+    reduce_scatter_to_sequence_parallel_region,
+)
+from megatron.core.transformer.moe.moe_utils import (
+    permute,
+    sort_chunks_by_idxs,
+    unpermute,
+)
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
@@ -303,7 +318,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         """
         assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
         hidden_states = self.combine_preprocess(hidden_states)
         permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
         output = self.combine_postprocess(permutated_local_input_tokens)
......
-from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention
-class MLASelfAttention(MegatronCoreMLASelfAttention):
+class MLASelfAttention():
     """MLA Self-attention layer class

     Self-attention layer takes input with size [s, b, h]
......
@@ -9,7 +9,7 @@ def transformer_block_init_wrapper(fn):
         # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
         config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "mtp_num_layers", 0) > 0:
+        if hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None:
             self.main_final_layernorm = self.final_layernorm
             self.final_layernorm = None
......
@@ -40,6 +40,9 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

+    split_bw: bool = False
+    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
+

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
......
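Note: the new `split_bw` flag lands in `ExtraTransformerConfig`, which `TransformerConfigPatch` mixes into the upstream `TransformerConfig`; that is also why the `hasattr(config, "split_bw")` guards in the Transformer Engine file are needed when the config patch is not applied. A minimal sketch of the mixin pattern, with a stand-in base config instead of the real `TransformerConfig`:

```python
from dataclasses import dataclass

@dataclass
class BaseConfig:                      # hypothetical stand-in for TransformerConfig
    hidden_size: int = 1024

@dataclass
class ExtraTransformerConfig:
    combined_1f1b_recipe: str = 'ep_a2a'
    split_bw: bool = False
    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""

@dataclass
class TransformerConfigPatch(BaseConfig, ExtraTransformerConfig):
    pass

cfg = TransformerConfigPatch(split_bw=True)
assert cfg.split_bw and cfg.hidden_size == 1024
```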
-from megatron.core import parallel_state, tensor_parallel
+from functools import partial
+from typing import Any, Optional
+
+import torch
+from torch import Tensor
+
+from megatron.core import tensor_parallel
+from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
 )
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
+from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


 class TransformerLayer(MegatronCoreTransformerLayer):
     def _callable_wrapper(
@@ -147,7 +156,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             tokens_per_expert,
             permutated_local_input_tokens,
             permuted_probs,
-            probs,
         ]
         return tuple(outputs)
......
+import warnings
+
 import torch
 from typing import Optional
-import lightop
+try:
+    import lightop
+except ImportError:
+    warnings.warn('lightop is not installed.')
 from functools import partial

 from megatron.core.utils import is_torch_min_version
......
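Note: wrapping `import lightop` in `try/except ImportError` turns a hard dependency into an optional one, which matches the commit's goal of fixing import errors on systems without the package. One common extension of this pattern keeps a module-level availability flag so call sites can branch cleanly; the `HAVE_LIGHTOP` flag and fallback below are illustrative additions, not part of the diff:

```python
import warnings

try:
    import lightop
    HAVE_LIGHTOP = True
except ImportError:
    warnings.warn('lightop is not installed; falling back to eager PyTorch ops.')
    lightop = None
    HAVE_LIGHTOP = False

def use_fused_kernels() -> bool:
    """Call sites check this instead of re-attempting the import."""
    return HAVE_LIGHTOP
```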
@@ -139,6 +139,8 @@ def _add_extra_moe_args(parser):
                        choices=['ep_a2a', 'golden'],
                        default='golden',
                        help='Options are "ep_a2a" and "golden".')
+    group.add_argument('--split-bw', action='store_true',
+                       help='Split dgrad and wgrad for batch-level overlapping')

     return parser
......
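Note: `--split-bw` is a plain `store_true` flag, so `args.split_bw` defaults to `False` and mirrors the `split_bw: bool = False` config field added above. A quick self-contained check of the argument wiring; the parser and group names here are illustrative, not the repo's actual `_add_extra_moe_args` setup:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='extra moe')
group.add_argument('--combined-1f1b-recipe', choices=['ep_a2a', 'golden'],
                   default='golden', help='Options are "ep_a2a" and "golden".')
group.add_argument('--split-bw', action='store_true',
                   help='Split dgrad and wgrad for batch-level overlapping')

args = parser.parse_args(['--split-bw'])
assert args.split_bw is True and args.combined_1f1b_recipe == 'golden'
```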