Commit e5f5eb4d authored by dongcl


Merge branch 'a2a_overlap' of http://developer.sourcefind.cn/codes/OpenDAS/dcu_megatron into a2a_overlap
parents ff5427cf 3947aa6c
@@ -773,7 +773,7 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_context =inference_context
+    state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
......
 import contextlib
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Tuple, Union
+from typing import List, Union
 import torch
 from torch import Tensor
@@ -16,10 +16,6 @@ from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
 # Types
 Shape = Union[List[int], torch.Size]
-def make_viewless(e):
-    """make_viewless util func"""
-    e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
@@ -351,7 +347,7 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
-    from .schedules import set_current_microbatch
+    from megatron.core.pipeline_parallel.schedules import set_current_microbatch
     if config.timers is not None:
         config.timers('forward-compute', log_level=2).start()
......
 import contextlib
-from typing import Callable, Iterator, List, Optional, Union
+from typing import Iterator, List, Union
 import torch
@@ -7,10 +7,8 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
-from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
     get_attr_wrapped_model,
     get_model_config,
     get_model_type,
     get_model_xattn,
......
 from functools import partial
 from typing import Any, Optional
 import torch
 from torch import Tensor
 from megatron.core import tensor_parallel
@@ -12,8 +10,7 @@ from megatron.core.utils import (
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
-from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 class TransformerLayer(MegatronCoreTransformerLayer):
@@ -34,7 +31,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         inference_params: Optional[Any] = None,
     ):
-        if not isinstance(self.mlp, MoELayer):
+        if (
+            not isinstance(self.mlp, MoELayer)
+            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
+        ):
             return super().forward(
                 hidden_states=hidden_states,
                 context=context,
@@ -55,7 +55,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            _,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
......