Commit e5f5eb4d authored by dongcl

Merge branch 'a2a_overlap' of http://developer.sourcefind.cn/codes/OpenDAS/dcu_megatron into a2a_overlap

parents ff5427cf 3947aa6c
@@ -773,7 +773,7 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_context =inference_context
+    state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
...
 import contextlib
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Tuple, Union
+from typing import List, Union
 import torch
 from torch import Tensor
@@ -16,10 +16,6 @@ from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
-# Types
-Shape = Union[List[int], torch.Size]
 def make_viewless(e):
     """make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
@@ -351,7 +347,7 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
-    from .schedules import set_current_microbatch
+    from megatron.core.pipeline_parallel.schedules import set_current_microbatch
     if config.timers is not None:
         config.timers('forward-compute', log_level=2).start()
...
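The hunk above also replaces the package-relative "from .schedules import ..." with the absolute module path, keeping the import deferred inside forward_backward_step. A minimal sketch of that deferred-import pattern follows; the function signature and body are illustrative rather than the repository's code, and the set_current_microbatch(model, microbatch_id) call assumes the upstream Megatron-LM signature.

def forward_backward_step(forward_step_func, data_iterator, model, config, current_microbatch=None):
    # Importing at call time keeps megatron.core.pipeline_parallel.schedules
    # out of this module's import-time dependencies, which is the usual way
    # to sidestep a circular import between the two modules.
    from megatron.core.pipeline_parallel.schedules import set_current_microbatch

    if current_microbatch is not None:
        # Assumed upstream behavior: record which microbatch index is being
        # run so that wrapped sub-modules can query it.
        set_current_microbatch(model, current_microbatch)
    ...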
 import contextlib
-from typing import Callable, Iterator, List, Optional, Union
+from typing import Iterator, List, Union
 import torch
@@ -7,10 +7,8 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
-from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
-    get_attr_wrapped_model,
     get_model_config,
     get_model_type,
     get_model_xattn,
...
-from functools import partial
 from typing import Any, Optional
-import torch
 from torch import Tensor
 from megatron.core import tensor_parallel
@@ -12,8 +10,7 @@ from megatron.core.utils import (
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
-from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
 class TransformerLayer(MegatronCoreTransformerLayer):
@@ -34,7 +31,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         inference_params: Optional[Any] = None,
     ):
-        if not isinstance(self.mlp, MoELayer):
+        if (
+            not isinstance(self.mlp, MoELayer)
+            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
+        ):
             return super().forward(
                 hidden_states=hidden_states,
                 context=context,
@@ -55,7 +55,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            _,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
...
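The rewritten guard makes TransformerLayer.forward take the compound attention/router path only when the MLP is a MoELayer whose token dispatcher is the all-to-all variant; every other configuration falls straight back to the stock Megatron-Core layer. A minimal sketch of the same check factored into a helper (the helper name is illustrative and not part of the repository):

from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher


def _uses_a2a_overlap_path(layer) -> bool:
    # True only for a MoE layer driven by the all-to-all token dispatcher; the
    # forward above returns super().forward(...) whenever this is False.
    return isinstance(layer.mlp, MoELayer) and isinstance(
        layer.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher
    )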