Unverified Commit 32cd7070 authored by fzyzcjy, committed by GitHub

Support TP in attention for two batch overlap (#6634)

parent ebd1ed49
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
         ):
             return CommunicateSummableTensorPairFn._gather
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.SCATTERED)
+        ):
+            return CommunicateSummableTensorPairFn._scatter
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
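As a quick orientation for the hunk above: which conversion function the communicator returns is decided purely from the (hidden_states input, residual input, output) scatter-mode triple. Below is a minimal, self-contained sketch of that dispatch pattern, using a simplified stand-in for sglang's ScatterMode and reproducing only the branch this commit adds; the real selector has further cases, such as the _gather branch visible above.

from enum import Enum, auto


class ScatterMode(Enum):
    # Simplified stand-in for sglang's ScatterMode; only the two members
    # involved in the new branch are modeled here.
    SCATTERED = auto()     # each attention-TP rank holds only its shard of tokens
    TP_ATTN_FULL = auto()  # every attention-TP rank holds the full token set


def pick_conversion(hidden_states_input_mode, residual_input_mode, output_mode):
    # Replicated inputs, scattered output -> the purely local "keep my shard"
    # conversion added by this commit.
    if (
        hidden_states_input_mode == ScatterMode.TP_ATTN_FULL
        and residual_input_mode == ScatterMode.TP_ATTN_FULL
        and output_mode == ScatterMode.SCATTERED
    ):
        return "scatter"
    raise NotImplementedError(
        f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
    )


assert (
    pick_conversion(
        ScatterMode.TP_ATTN_FULL, ScatterMode.TP_ATTN_FULL, ScatterMode.SCATTERED
    )
    == "scatter"
)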
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
             local_hidden_states,
         )
         return hidden_states, residual
+
+    @staticmethod
+    def _scatter(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: CommunicateContext,
+    ):
+        assert residual is None, "not yet handled residual!=None"
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        return hidden_states, residual
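For orientation, a self-contained sketch of what the new _scatter step computes, outside of any distributed setup: torch.tensor_split produces attn_tp_size near-equal shards (batches not divisible by the TP size are allowed) and each rank keeps only its own shard. The concrete sizes and rank below are illustrative, not taken from CommunicateContext.

import torch

# Illustrative values; in the real code they come from context.attn_tp_size
# and context.attn_tp_rank of the attention tensor-parallel group.
attn_tp_size, attn_tp_rank = 4, 1

# Full batch of 10 tokens with hidden size 8, replicated on every rank
# (ScatterMode.TP_ATTN_FULL).
hidden_states = torch.randn(10, 8)

# tensor_split tolerates uneven division: 10 tokens over 4 ranks -> 3, 3, 2, 2.
shards = list(hidden_states.tensor_split(attn_tp_size))
local_hidden_states = shards[attn_tp_rank]

assert [s.shape[0] for s in shards] == [3, 3, 2, 2]
assert local_hidden_states.shape == (3, 8)  # rank 1 keeps tokens 3..5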
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_num_layers - 1
+                ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states, residual = model_forward_maybe_tbo(
             layers=self.layers,
             enable_tbo=True,
+            input_data_scatter_mode=ScatterMode.model_input_output(),
             positions=positions,
             forward_batch=forward_batch,
             hidden_states=hidden_states,
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.communicator import (
+    CommunicateContext,
+    CommunicateSimpleFn,
+    CommunicateSummableTensorPairFn,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
+    input_data_scatter_mode: ScatterMode,
     residual: Optional[torch.Tensor],
     zero_allocator: Optional[BumpAllocator] = None,
 ):
@@ -365,20 +372,32 @@ def model_forward_maybe_tbo(
         residual=residual,
         **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
+    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
     )
     if enable_tbo:
-        return _model_forward_tbo(inputs, operations_strategy)
+        return _model_forward_tbo(
+            inputs=inputs,
+            operations_strategy=operations_strategy,
+            input_data_scatter_mode=input_data_scatter_mode,
+            layer_input_scatter_mode=layer_input_scatter_mode,
+        )
     else:
         return _model_forward_non_tbo(inputs, operations_strategy)
 
 
-def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
-    # The attn_tp_size!=1 case is not yet extracted to master
-    assert get_attention_tp_size() == 1
-
-    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
+def _model_forward_tbo(
+    inputs,
+    operations_strategy: OperationsStrategy,
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+):
+    inputs_arr = _model_forward_tbo_split_inputs(
+        **inputs,
+        input_data_scatter_mode=input_data_scatter_mode,
+        layer_input_scatter_mode=layer_input_scatter_mode,
+    )
     del inputs
 
     with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: Optional[BumpAllocator] = None,
+    zero_allocator: Optional[BumpAllocator],
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+) -> List[Dict]:
+    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
+
+    context = CommunicateContext.init_new()
+    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+        hidden_states_input_mode=input_data_scatter_mode,
+        residual_input_mode=input_data_scatter_mode,
+        output_mode=tbo_splitter_scatter_mode,
+        hidden_states=hidden_states,
+        residual=residual,
+        forward_batch=forward_batch,
+        context=context,
+    )
+
+    inputs_arr = _model_forward_tbo_split_inputs_raw(
+        hidden_states=hidden_states,
+        residual=residual,
+        positions=positions,
+        forward_batch=forward_batch,
+        zero_allocator=zero_allocator,
+    )
+
+    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
+        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+            hidden_states_input_mode=tbo_splitter_scatter_mode,
+            residual_input_mode=tbo_splitter_scatter_mode,
+            output_mode=layer_input_scatter_mode,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            context=context,
+        )
+        return dict(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+    return [_post_transform(**inputs) for inputs in inputs_arr]
+
+
+def _model_forward_tbo_split_inputs_raw(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    zero_allocator: Optional[BumpAllocator],
 ) -> List[Dict]:
     return [
         dict(
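To summarize the data flow the hunk above sets up, here is a simplified single-process sketch: inputs are first converted from the model's input scatter mode to TP_ATTN_FULL so every attention-TP rank sees the same full token set, the batch is then split into the two overlapped micro-batches, and each micro-batch is finally converted to the scatter mode the layers expect before the per-layer operations run. The convert_mode stub below stands in for CommunicateSummableTensorPairFn.execute and a real CommunicateContext; split_for_tbo, split_index, and the tensor shapes are made up for illustration.

from typing import Dict, List

import torch


def convert_mode(hidden_states: torch.Tensor, input_mode: str, output_mode: str) -> torch.Tensor:
    # Stand-in for CommunicateSummableTensorPairFn.execute: the real version
    # may all-gather (SCATTERED -> TP_ATTN_FULL), keep only the local shard
    # (TP_ATTN_FULL -> SCATTERED), or do nothing when the modes already match.
    # Both modes are equal in the usage below, so identity is accurate here.
    return hidden_states


def split_for_tbo(
    hidden_states: torch.Tensor,
    input_mode: str,
    layer_input_mode: str,
    split_index: int,
) -> List[Dict]:
    # 1) Bring the inputs to TP_ATTN_FULL so the token-level split below is
    #    identical on every attention-TP rank.
    full = convert_mode(hidden_states, input_mode, "TP_ATTN_FULL")
    # 2) Split the tokens into the two overlapped micro-batches.
    micro_batches = [full[:split_index], full[split_index:]]
    # 3) Convert each micro-batch to the scatter mode the first layer expects.
    return [
        dict(hidden_states=convert_mode(mb, "TP_ATTN_FULL", layer_input_mode), residual=None)
        for mb in micro_batches
    ]


inputs_arr = split_for_tbo(
    torch.randn(10, 8), "TP_ATTN_FULL", "TP_ATTN_FULL", split_index=6
)
assert [d["hidden_states"].shape[0] for d in inputs_arr] == [6, 4]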