Unverified commit 32cd7070 authored by fzyzcjy, committed by GitHub

Support TP in attention for two batch overlap (#6634)

parent ebd1ed49
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
         ):
             return CommunicateSummableTensorPairFn._gather
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.SCATTERED)
+        ):
+            return CommunicateSummableTensorPairFn._scatter
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
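For orientation, the dispatch above is essentially a lookup from the triple (hidden_states_input_mode, residual_input_mode, output_mode) to a communication primitive; this commit adds the TP_ATTN_FULL/TP_ATTN_FULL -> SCATTERED combination and routes it to the new _scatter helper. A minimal, self-contained sketch of the same pattern follows; the Mode enum, the placeholder handlers, and the key chosen for the gather entry are illustrative assumptions, not sglang's actual definitions.

from enum import Enum, auto
from typing import Callable, Dict, Tuple


class Mode(Enum):
    # Illustrative stand-ins for the two ScatterMode members used above.
    SCATTERED = auto()     # each attention-TP rank holds only its slice of the tokens
    TP_ATTN_FULL = auto()  # every attention-TP rank holds the full padded token set


def gather(x):
    return x  # placeholder for the all-gather path (_gather)


def scatter(x):
    return x  # placeholder for the new slicing path (_scatter)


# Dispatch keyed by (hidden_states_input_mode, residual_input_mode, output_mode).
# The gather entry's key is a guess for illustration; only the scatter entry
# mirrors the combination added in this commit.
DISPATCH: Dict[Tuple[Mode, Mode, Mode], Callable] = {
    (Mode.SCATTERED, Mode.SCATTERED, Mode.TP_ATTN_FULL): gather,
    (Mode.TP_ATTN_FULL, Mode.TP_ATTN_FULL, Mode.SCATTERED): scatter,
}


def get_fn(hidden_in: Mode, residual_in: Mode, out: Mode) -> Callable:
    try:
        return DISPATCH[(hidden_in, residual_in, out)]
    except KeyError:
        raise NotImplementedError(f"{hidden_in=} {residual_in=} {out=}")


assert get_fn(Mode.TP_ATTN_FULL, Mode.TP_ATTN_FULL, Mode.SCATTERED) is scatter

A table-driven form like this expresses the same mapping; the if-chain in the real code leaves room for extra guards in individual branches.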
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
             local_hidden_states,
         )
         return hidden_states, residual
+
+    @staticmethod
+    def _scatter(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: CommunicateContext,
+    ):
+        assert residual is None, "not yet handled residual!=None"
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        return hidden_states, residual
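The scatter itself is purely local slicing: tensor_split cuts the padded full-batch tensor into attn_tp_size near-equal chunks along the token dimension, and each rank keeps the chunk at its own attn_tp_rank. A single-process sketch of that mechanic, with ranks simulated in a loop and made-up sizes (plain torch, no distributed initialization):

import torch

attn_tp_size = 4
hidden_size = 8
num_padded_tokens = 10  # the padded full batch every rank sees in TP_ATTN_FULL mode

full_hidden_states = torch.randn(num_padded_tokens, hidden_size)

# tensor_split tolerates uneven division: 10 tokens over 4 ranks -> 3, 3, 2, 2.
chunks = list(full_hidden_states.tensor_split(attn_tp_size))

for attn_tp_rank in range(attn_tp_size):
    local_hidden_states = chunks[attn_tp_rank]  # what _scatter keeps on this rank
    print(attn_tp_rank, tuple(local_hidden_states.shape))

# Concatenating the per-rank chunks in rank order reproduces the full tensor,
# i.e. the inverse of what the gather path performs.
assert torch.equal(torch.cat(chunks, dim=0), full_hidden_states)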
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_num_layers - 1
+                ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )
......
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states, residual = model_forward_maybe_tbo(
             layers=self.layers,
             enable_tbo=True,
+            input_data_scatter_mode=ScatterMode.model_input_output(),
             positions=positions,
             forward_batch=forward_batch,
             hidden_states=hidden_states,
......
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 import torch

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.communicator import (
+    CommunicateContext,
+    CommunicateSimpleFn,
+    CommunicateSummableTensorPairFn,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
+    input_data_scatter_mode: ScatterMode,
     residual: Optional[torch.Tensor],
     zero_allocator: Optional[BumpAllocator] = None,
 ):
@@ -365,20 +372,32 @@ def model_forward_maybe_tbo(
         residual=residual,
         **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
+    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
     )
     if enable_tbo:
-        return _model_forward_tbo(inputs, operations_strategy)
+        return _model_forward_tbo(
+            inputs=inputs,
+            operations_strategy=operations_strategy,
+            input_data_scatter_mode=input_data_scatter_mode,
+            layer_input_scatter_mode=layer_input_scatter_mode,
+        )
     else:
         return _model_forward_non_tbo(inputs, operations_strategy)


-def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
-    # The attn_tp_size!=1 case is not yet extracted to master
-    assert get_attention_tp_size() == 1
-    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
+def _model_forward_tbo(
+    inputs,
+    operations_strategy: OperationsStrategy,
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+):
+    inputs_arr = _model_forward_tbo_split_inputs(
+        **inputs,
+        input_data_scatter_mode=input_data_scatter_mode,
+        layer_input_scatter_mode=layer_input_scatter_mode,
+    )
     del inputs
     with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
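With the get_attention_tp_size() == 1 assertion gone, _model_forward_tbo simply forwards the two scatter modes to the splitter. Conceptually, two batch overlap cuts every token-indexed input at one split point and runs the two halves as independent micro-batches. A simplified, self-contained sketch of that splitting step, using hypothetical names and only hidden_states and positions (the real splitter also carries residual, forward_batch and the zero allocator):

import torch
from typing import Dict, List


def split_into_two_micro_batches(
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    split_index: int,
) -> List[Dict[str, torch.Tensor]]:
    # Cut every token-indexed tensor at the same split point.
    return [
        dict(
            hidden_states=hidden_states[:split_index],
            positions=positions[:split_index],
        ),
        dict(
            hidden_states=hidden_states[split_index:],
            positions=positions[split_index:],
        ),
    ]


# 16 tokens become micro-batches of 9 and 7 tokens.
micro_batches = split_into_two_micro_batches(
    hidden_states=torch.randn(16, 8),
    positions=torch.arange(16),
    split_index=9,
)
print([tuple(mb["hidden_states"].shape) for mb in micro_batches])  # [(9, 8), (7, 8)]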
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: Optional[BumpAllocator] = None,
+    zero_allocator: Optional[BumpAllocator],
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
 ) -> List[Dict]:
+    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
+    context = CommunicateContext.init_new()
+
+    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+        hidden_states_input_mode=input_data_scatter_mode,
+        residual_input_mode=input_data_scatter_mode,
+        output_mode=tbo_splitter_scatter_mode,
+        hidden_states=hidden_states,
+        residual=residual,
+        forward_batch=forward_batch,
+        context=context,
+    )
+
+    inputs_arr = _model_forward_tbo_split_inputs_raw(
+        hidden_states=hidden_states,
+        residual=residual,
+        positions=positions,
+        forward_batch=forward_batch,
+        zero_allocator=zero_allocator,
+    )
+
+    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
+        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+            hidden_states_input_mode=tbo_splitter_scatter_mode,
+            residual_input_mode=tbo_splitter_scatter_mode,
+            output_mode=layer_input_scatter_mode,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            context=context,
+        )
+        return dict(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+    return [_post_transform(**inputs) for inputs in inputs_arr]
+
+
+def _model_forward_tbo_split_inputs_raw(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    zero_allocator: Optional[BumpAllocator],
+) -> List[Dict]:
     return [
         dict(
......
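The reworked _model_forward_tbo_split_inputs above wraps the raw split in a scatter-mode sandwich: convert the whole batch from input_data_scatter_mode to TP_ATTN_FULL, perform the plain split, then run each micro-batch through _post_transform to reach the first layer's input mode. A rough single-process sketch of that control flow, in which the conversion helpers are identity placeholders standing in for CommunicateSummableTensorPairFn.execute and all names are illustrative:

import torch
from typing import Dict, List


def convert_to_splitter_mode(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder for the pre-split conversion to TP_ATTN_FULL
    # (e.g. an all-gather when the input arrives SCATTERED).
    return hidden_states


def convert_to_layer_input_mode(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder for the post-split conversion to layers[0]'s input mode
    # (e.g. the new _scatter when the layer expects SCATTERED tokens).
    return hidden_states


def raw_split(hidden_states: torch.Tensor, split_index: int) -> List[Dict[str, torch.Tensor]]:
    # Stand-in for _model_forward_tbo_split_inputs_raw: a plain token-dimension split.
    return [
        dict(hidden_states=hidden_states[:split_index]),
        dict(hidden_states=hidden_states[split_index:]),
    ]


def split_inputs_with_mode_sandwich(
    hidden_states: torch.Tensor, split_index: int
) -> List[Dict[str, torch.Tensor]]:
    # 1) Normalize the whole batch to the splitter's mode (TP_ATTN_FULL).
    hidden_states = convert_to_splitter_mode(hidden_states)
    # 2) Split into the two TBO micro-batches.
    micro_batches = raw_split(hidden_states, split_index)

    # 3) Convert each micro-batch to the mode the first layer expects,
    #    mirroring the _post_transform closure in the diff above.
    def _post_transform(hidden_states: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        return dict(hidden_states=convert_to_layer_input_mode(hidden_states), **kwargs)

    return [_post_transform(**mb) for mb in micro_batches]


out = split_inputs_with_mode_sandwich(torch.randn(12, 4), split_index=7)
print([tuple(mb["hidden_states"].shape) for mb in out])  # [(7, 4), (5, 4)]

Keeping both conversions at the splitter boundary means each micro-batch reaches layers[0] already in the mode it expects, which is what allows two batch overlap to run with attention TP sizes larger than one.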