Commit 5890bb4c authored by dongcl

fix import error

parent 4bb958ec
......@@ -5,13 +5,17 @@ def a2a_overlap_adaptation(patches_manager):
"""
patches_manager: MegatronPatchesManager
"""
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from ..core.transformer.transformer_block import TransformerBlock
from ..core.transformer.transformer_layer import TransformerLayer
from ..core.models.gpt.gpt_model import GPTModel
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
from ..core.extensions.transformer_engine import (
_get_extra_te_kwargs_wrapper,
TELinear,
TELayerNormColumnParallelLinear,
)
from ..core.transformer.multi_latent_attention import MLASelfAttention
from ..core.transformer.mlp import MLP
from ..core.transformer.moe.experts import TEGroupedMLP
......@@ -38,23 +42,34 @@ def a2a_overlap_adaptation(patches_manager):
GPTModel)
# backward_dw
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
# patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
# _get_extra_te_kwargs_wrapper,
# apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
TELayerNormColumnParallelLinear)
TEColumnParallelLinear.__bases__ = (TELinear,)
TERowParallelLinear.__bases__ = (TELinear,)
if is_te_min_version("1.9.0.dev0"):
from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
from ..core.extensions.transformer_engine import TEGroupedLinear
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
TEGroupedLinear)
TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
MLASelfAttention)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
MLP)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
TEGroupedMLP)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
MoELayer)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
MLASelfAttention.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
MLP.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
TEGroupedMLP.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
MoELayer.backward_dw,
create_dummy=True)
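For context on the `__bases__` reassignments above: rebasing the TE subclasses onto the patched `TELinear` (and, for grouped linears, onto the patched `TEGroupedLinear`) makes the existing subclasses inherit whatever the patched parents add, without registering each subclass separately. A minimal standalone sketch of the mechanism, using hypothetical stand-in classes rather than the real Megatron/TransformerEngine ones:

```python
# Hypothetical stand-ins to illustrate the __bases__ swap; not the real classes.
class BaseLinear:                        # plays the role of the upstream TELinear
    def backward_dw(self):
        raise NotImplementedError("no delayed wgrad in the upstream class")

class PatchedLinear(BaseLinear):         # plays the role of the patched TELinear
    def backward_dw(self):
        return "wgrad computed later"

class ColumnParallelLinear(BaseLinear):  # plays the role of TEColumnParallelLinear
    pass

# Same trick as `TEColumnParallelLinear.__bases__ = (TELinear,)` above: the
# subclass now inherits from the patched parent, so instances created anywhere
# in the codebase see the added method.
ColumnParallelLinear.__bases__ = (PatchedLinear,)
print(ColumnParallelLinear().backward_dw())   # -> "wgrad computed later"
```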
......@@ -104,7 +104,7 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
def patch_core_transformers(self):
......@@ -122,12 +122,12 @@ class CoreAdaptation(MegatronAdaptationABC):
MLATransformerConfigPatch)
# Moe
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
# apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
......
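A side note on the `torch.compile(...)` registrations in this hunk (both the commented-out ones and the `permute` one that remains): calling `torch.compile` without a function returns a decorator, which is why it can be passed to the registry with `apply_wrapper=True`. A minimal standalone sketch with a hypothetical stand-in for `moe_utils.permute`:

```python
import torch

# torch.compile(...) with only keyword arguments returns a decorator; the
# registry then applies it to the target function, as the registrations above do.
compile_wrapper = torch.compile(mode='max-autotune-no-cudagraphs')

@compile_wrapper
def permute_like(x):  # hypothetical stand-in for moe_utils.permute
    return x.transpose(0, 1).contiguous()

print(permute_like(torch.randn(4, 8)).shape)  # torch.Size([8, 4])
```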
......@@ -3,6 +3,7 @@ import torch
import dataclasses
import transformer_engine as te
from functools import wraps
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion
......@@ -18,7 +19,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCore
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
......@@ -29,7 +29,7 @@ def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw if hasattr(config, "split_bw") else False
return extra_transformer_engine_kwargs
return wrapper
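The `config.get("split_bw", False)` calls replaced throughout this file fail because `TransformerConfig` is a dataclass, not a dict. The sketch below shows a behaviourally equivalent `getattr` form of the new guard; it is an illustration, not part of the patch:

```python
from functools import wraps

def _get_extra_te_kwargs_wrapper(fn):
    @wraps(fn)
    def wrapper(config):
        extra_transformer_engine_kwargs = fn(config)
        # getattr(config, "split_bw", False) is equivalent to
        # `config.split_bw if hasattr(config, "split_bw") else False`
        extra_transformer_engine_kwargs["delay_wgrad_compute"] = getattr(config, "split_bw", False)
        return extra_transformer_engine_kwargs
    return wrapper
```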
......@@ -66,7 +66,7 @@ class TELinear(MegatronCoreTELinear):
is_expert: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.get("split_bw", False)
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
......@@ -109,7 +109,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.get("split_bw", False)
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
......@@ -314,7 +314,7 @@ if is_te_min_version("1.9.0.dev0"):
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = config.get("split_bw", False)
self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
......
import contextlib
import weakref
from typing import Any, Callable, Optional, Tuple, Union
from collections import OrderedDict
from typing import Optional
import torch
from torch import Tensor
from megatron.core.pipeline_parallel.combined_1f1b import (
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import transformer_layer
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.utils import deprecate_inference_params
from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
AbstractSchedulePlan,
ScheduleNode,
get_com_stream,
get_comp_stream,
make_viewless,
)
from megatron.core.transformer import transformer_layer
from megatron.core.transformer.module import float16_to_fp32
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
def weak_method(method):
......@@ -43,6 +48,7 @@ class PreProcessNode(ScheduleNode):
input_ids = self.model_chunk_state.input_ids
position_ids = self.model_chunk_state.position_ids
inference_context = self.model_chunk_state.inference_context
inference_params = self.model_chunk_state.inference_params
packed_seq_params = self.model_chunk_state.packed_seq_params
inference_context = deprecate_inference_params(inference_context, inference_params)
......@@ -121,22 +127,6 @@ class PostProcessNode(ScheduleNode):
self.gpt_model = gpt_model
self.model_chunk_state = model_chunk_state
state.input_ids = input_ids
state.position_ids = position_ids
state.attention_mask = attention_mask
state.decoder_input = decoder_input
state.labels = labels
state.inference_context = inference_context
state.packed_seq_params = packed_seq_params
state.extra_block_kwargs = extra_block_kwargs
state.runtime_gather_output = runtime_gather_output
state.inference_params = inference_params
state.loss_mask = loss_mask
state.context = None
state.context_mask = None
state.attention_bias = None
def forward_impl(self, hidden_states):
gpt_model = self.gpt_model
......@@ -145,11 +135,13 @@ class PostProcessNode(ScheduleNode):
labels = self.model_chunk_state.labels
loss_mask = self.model_chunk_state.loss_mask
attention_mask = self.model_chunk_state.attention_mask
decoder_input = self.model_chunk_state.decoder_input
inference_params = self.model_chunk_state.inference_params
rotary_pos_emb = self.model_chunk_state.rotary_pos_emb
rotary_pos_cos = self.model_chunk_state.rotary_pos_cos
rotary_pos_sin = self.model_chunk_state.rotary_pos_sin
packed_seq_params = self.model_chunk_state.packed_seq_params
extra_block_kwargs = self.model_chunk_state.extra_block_kwargs
sequence_len_offset = self.model_chunk_state.sequence_len_offset
runtime_gather_output = self.model_chunk_state.runtime_gather_output
inference_context = self.model_chunk_state.inference_context
......@@ -267,6 +259,9 @@ class TransformerLayerNode(ScheduleNode):
def backward_impl(self, outputs, output_grad):
detached_grad = tuple([e.grad for e in self.detached])
grads = output_grad + detached_grad
# if len(detached_grad):
# print(f"output_grad: {grads}")
self.default_backward_func(outputs + self.before_detached, grads)
self.before_detached = None
self.detached = None
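On the detach bookkeeping in `backward_impl`: gradients cross the manually scheduled segment boundaries by detaching activations at the boundary, running backward on the later segment, and then feeding the `.grad` accumulated on the detached tensors into the earlier segment's backward (the role of `self.detached` and `self.before_detached`). A reduced standalone sketch of the idea with hypothetical tensors:

```python
import torch

# Segment boundary: detach to cut the graph, but keep the detached tensor so
# its .grad can be spliced into the upstream segment afterwards.
x = torch.randn(4, requires_grad=True)
segment1_out = x * 2
boundary = segment1_out.detach().requires_grad_(True)
segment2_out = (boundary ** 2).sum()

segment2_out.backward()                                 # backward through segment 2 only
torch.autograd.backward(segment1_out, boundary.grad)    # splice into segment 1
print(x.grad)                                           # full chain-rule gradient w.r.t. x
```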
......@@ -296,7 +291,6 @@ class MoeAttnNode(TransformerLayerNode):
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
) = self.layer._submodule_attention_router_compound_forward(
hidden_states,
attention_mask=attention_mask,
......@@ -312,7 +306,6 @@ class MoeAttnNode(TransformerLayerNode):
self.common_state.tokens_per_expert = tokens_per_expert
# detached here
self.common_state.probs = self.detach(probs)
self.common_state.residual = self.detach(hidden_states)
self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
......@@ -334,7 +327,7 @@ class MoeDispatchNode(TransformerLayerNode):
)
# release tensor not used by backward
# inputs.untyped_storage().resize_(0)
self.common_state.tokens_per_expert = = tokens_per_expert
self.common_state.tokens_per_expert = tokens_per_expert
return global_input_tokens, global_probs
......@@ -345,7 +338,7 @@ class MoeMlPNode(TransformerLayerNode):
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
self.common_state.tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output
self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
)
assert mlp_bias is None
......@@ -372,9 +365,7 @@ class MoeCombineNode(TransformerLayerNode):
)
cur_stream = torch.cuda.current_stream()
self.common_state.residual.record_stream(cur_stream)
self.common_state.probs.record_stream(cur_stream)
self.common_state.residual = None
self.common_state.probs = None
return output
......@@ -554,21 +545,18 @@ def schedule_layer_1f1b(
f_context = f_context if f_context is not None else contextlib.nullcontext()
b_context = b_context if b_context is not None else contextlib.nullcontext()
if pre_forward is not None:
assert f_input is None
# combine from last iter
f_input = pre_forward()
del pre_forward
if pre_backward is not None:
# attn backward from last iter
assert b_grad is None
b_grad = pre_backward()
del pre_backward
if b_layer is not None:
with b_context:
b_grad = b_layer.combine.backward(b_grad)
......@@ -577,7 +565,6 @@ def schedule_layer_1f1b(
pre_backward_dw()
del pre_backward_dw
if f_layer is not None:
with f_context:
f_input = f_layer.attn.forward(f_input)
......@@ -592,7 +579,6 @@ def schedule_layer_1f1b(
b_grad = b_layer.dispatch.backward(b_grad)
b_layer.mlp.dw()
if f_layer is not None:
with f_context:
f_input = f_layer.mlp.forward(f_input)
......@@ -614,7 +600,6 @@ def schedule_layer_1f1b(
with b_context:
b_layer.attn.dw()
if f_layer and b_layer:
return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
else:
......
......@@ -7,10 +7,10 @@ from functools import wraps
import torch
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
......
......@@ -427,7 +427,7 @@ def forward_backward_step(
if f_model:
with f_context:
num_tokens = torch.tensor(0, dtype=torch.int)
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if not collect_non_loss_data:
loss_node = ScheduleNode(
loss_func,
......
......@@ -2,17 +2,13 @@ import contextlib
from typing import Callable, Iterator, List, Optional, Union
import torch
from torch.autograd.variable import Variable
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
get_attr_wrapped_model,
get_model_config,
get_model_type,
......@@ -32,6 +28,18 @@ from megatron.core.pipeline_parallel.schedules import (
from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func
def set_current_microbatch(model, microbatch_id):
"""Set the current microbatch."""
decoder_exists = True
decoder = None
try:
decoder = get_attr_wrapped_model(model, "decoder")
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
decoder.current_microbatch = microbatch_id
def get_pp_rank_microbatches(
num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
):
......@@ -541,7 +549,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if parallel_state.is_pipeline_first_stage():
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -573,7 +581,7 @@ def forward_backward_pipelining_with_interleaving(
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
b_input_tensor = input_tensors[model_chunk_id].pop(0)
......@@ -679,7 +687,6 @@ def forward_backward_pipelining_with_interleaving(
post_backward=post_backward,
)
else:
output_tensor = None
input_tensor_grad = None
if f_virtual_microbatch_id is not None:
# forward pass
......@@ -704,7 +711,7 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
if post_backward is not None:
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad
# Run warmup forward passes.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
......@@ -890,6 +897,7 @@ def forward_backward_pipelining_with_interleaving(
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
# Run 1F1B in steady state.
output_tensor = None
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
......
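The two `output_tensor` changes in this file (pre-initialising it before the steady-state loop and returning it only when a forward microbatch actually ran) avoid handing back a stale or unbound value on backward-only iterations. A reduced sketch of the None-safe return with a hypothetical helper, not the real scheduling code:

```python
def steady_state_step(f_virtual_microbatch_id, b_virtual_microbatch_id,
                      forward_fn, backward_fn):
    """Hypothetical reduction of a 1F1B step: forward and/or backward."""
    output_tensor = None                      # mirrors the added initialisation
    if f_virtual_microbatch_id is not None:   # forward pass, if scheduled
        output_tensor = forward_fn(f_virtual_microbatch_id)
    input_tensor_grad = None
    if b_virtual_microbatch_id is not None:   # backward pass, if scheduled
        input_tensor_grad = backward_fn(b_virtual_microbatch_id)
    # mirrors the patched return: no forward microbatch -> no output tensor
    return output_tensor if f_virtual_microbatch_id is not None else None, input_tensor_grad

print(steady_state_step(None, 3, lambda i: "acts", lambda i: "grads"))  # (None, 'grads')
```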
from megatron.core.transformer.mlp import MLP as MegatronCoreMLP
class MLP(MegatronCoreMLP):
class MLP():
def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
from megatron.core.transformer.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP
class TEGroupedMLP(MegatronCoreTEGroupedMLP):
class TEGroupedMLP():
def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer
class MoELayer(MegatronCoreMoELayer):
class MoELayer():
def backward_dw(self):
self.experts.backward_dw()
self.shared_experts.backward_dw()
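With the base classes and their broken imports removed, `MLP`, `TEGroupedMLP`, and `MoELayer` here are plain method holders; the `register_patch(..., create_dummy=True)` calls in the first hunk are what graft `backward_dw` onto the upstream Megatron classes. A minimal sketch of such grafting, assuming the registry resolves dotted paths and assigns the function (the real MegatronPatchesManager API may differ):

```python
import importlib

def attach_method(dotted_path, func):
    # "megatron.core.transformer.mlp.MLP.backward_dw" splits into the module
    # path, the class name, and the attribute to create on that class.
    module_path, cls_name, attr_name = dotted_path.rsplit(".", 2)
    cls = getattr(importlib.import_module(module_path), cls_name)
    setattr(cls, attr_name, func)   # a plain function becomes a method here

# e.g. attach_method("megatron.core.transformer.mlp.MLP.backward_dw",
#                    MLP.backward_dw)   # MLP here is the holder class above
```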
from contextlib import contextmanager
from typing import Optional, Tuple
import torch
from megatron.core.tensor_parallel import (
all_to_all,
gather_from_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.moe.moe_utils import (
permute,
sort_chunks_by_idxs,
unpermute,
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
......@@ -303,7 +318,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
hidden_states = self.combine_preprocess(hidden_states)
hidden_states = self.combine_preprocess(hidden_states)
permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
output = self.combine_postprocess(permutated_local_input_tokens)
......
from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention
class MLASelfAttention(MegatronCoreMLASelfAttention):
class MLASelfAttention():
"""MLA Self-attention layer class
Self-attention layer takes input with size [s, b, h]
......
......@@ -9,7 +9,7 @@ def transformer_block_init_wrapper(fn):
# mtp requires separate layernorms for the main model and the mtp modules, thus move the final layernorm out of the block
config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0:
if hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
......
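On the changed `mtp_num_layers` check (the +/- markers are lost in this dump): `getattr(config, "mtp_num_layers", 0) > 0` raises a TypeError once the attribute exists but is None, while the `hasattr`/`is not None` form is None-safe (though it also accepts 0). A reduced illustration with a hypothetical config object:

```python
from types import SimpleNamespace

config = SimpleNamespace(mtp_num_layers=None)   # hypothetical config object

# getattr(config, "mtp_num_layers", 0) > 0
#   -> TypeError: '>' not supported between instances of 'NoneType' and 'int'

mtp_enabled = hasattr(config, "mtp_num_layers") and config.mtp_num_layers is not None
print(mtp_enabled)   # False while mtp_num_layers is None; True once it is set
```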
......@@ -40,6 +40,9 @@ class ExtraTransformerConfig:
combined_1f1b_recipe: str = 'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
split_bw: bool = False
"""If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
@dataclass
class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
......
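The extra fields ride along as a dataclass mixin, so the `hasattr(config, "split_bw")` guards elsewhere in this commit see the new attribute once `TransformerConfigPatch` is in use. A reduced standalone sketch with simplified field sets (the real `TransformerConfig` has many more fields):

```python
from dataclasses import dataclass

@dataclass
class BaseConfig:                  # stands in for megatron's TransformerConfig
    hidden_size: int = 1024

@dataclass
class ExtraConfig:                 # mirrors ExtraTransformerConfig above
    combined_1f1b_recipe: str = 'ep_a2a'
    split_bw: bool = False

@dataclass
class ConfigPatch(BaseConfig, ExtraConfig):   # mirrors TransformerConfigPatch
    pass

cfg = ConfigPatch()
print(cfg.split_bw)                # False -> visible to the hasattr() guards
```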
from megatron.core import parallel_state, tensor_parallel
from functools import partial
from typing import Any, Optional
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
)
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
class TransformerLayer(MegatronCoreTransformerLayer):
def _callable_wrapper(
......@@ -147,7 +156,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
]
return tuple(outputs)
......
import warnings
import torch
from typing import Optional
import lightop
try:
import lightop
except ImportError:
warnings.warn('lightop is not installed.')
from functools import partial
from megatron.core.utils import is_torch_min_version
......
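The guarded import above is the usual optional-dependency pattern; the variant below also records availability so call sites can branch instead of hitting a NameError later. The `HAVE_LIGHTOP` flag is an illustration, not part of the patch:

```python
import warnings

try:
    import lightop
    HAVE_LIGHTOP = True
except ImportError:
    lightop = None
    HAVE_LIGHTOP = False
    warnings.warn('lightop is not installed.')

# callers can then guard with `if HAVE_LIGHTOP: ...` instead of risking a NameError
```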
......@@ -139,6 +139,8 @@ def _add_extra_moe_args(parser):
choices=['ep_a2a', 'golden'],
default='golden',
help='Options are "ep_a2a" and "golden".')
group.add_argument('--split-bw', action='store_true',
help='Split dgrad and wgrad for batch-level overlapping')
return parser
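A hypothetical standalone check of the new flag (only `--split-bw` is shown here; the wiring from `args.split_bw` into `TransformerConfigPatch.split_bw` is assumed to happen in the existing argument-to-config plumbing):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='moe')   # group title is illustrative
group.add_argument('--split-bw', action='store_true',
                   help='Split dgrad and wgrad for batch-level overlapping')

print(parser.parse_args([]).split_bw)              # False (store_true default)
print(parser.parse_args(['--split-bw']).split_bw)  # True
```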
......