Unverified commit b87aacb5 authored by Cheng Wan, committed by GitHub

[DP Attention] Refactor: adding some utility functions (#9136)

parent b3363cc1
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
 get_attention_dp_size,
 get_attention_tp_rank,
 get_attention_tp_size,
+get_global_dp_buffer,
+get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
 context: CommunicateContext,
 ) -> torch.Tensor:
 hidden_states, local_hidden_states = (
-forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+get_local_dp_buffer(),
 hidden_states,
 )
 attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
 ):
 if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
 residual, local_residual = (
-torch.empty_like(
-forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-),
+get_local_dp_buffer(),
 residual,
 )
 attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
 residual = hidden_states
 hidden_states = layernorm(hidden_states)
 hidden_states, local_hidden_states = (
-torch.empty_like(forward_batch.gathered_buffer),
+get_global_dp_buffer(),
 hidden_states,
 )
 dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
 allow_reduce_scatter: bool = False,
 ):
 hidden_states, global_hidden_states = (
-forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+get_local_dp_buffer(),
 hidden_states,
 )
 if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
 hidden_states += residual
 residual = None
 hidden_states, local_hidden_states = (
-forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+get_local_dp_buffer(),
 hidden_states,
 )
 attn_tp_all_gather_into_tensor(
...
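
Every hunk in this file makes the same substitution: call sites stop slicing the per-batch `forward_batch.gathered_buffer` and instead request a freshly sized tensor from the new helpers added in `dp_attention.py` below. Condensed before/after of the first hunk (the all-gather arguments are elided here):

```python
# Before: slice a preallocated, batch-sized gather buffer
hidden_states, local_hidden_states = (
    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
    hidden_states,
)

# After: allocate a correctly sized local buffer on demand
hidden_states, local_hidden_states = (
    get_local_dp_buffer(),
    hidden_states,
)
attn_tp_all_gather_into_tensor(...)  # unchanged
```
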
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
 tensor_model_parallel_all_reduce,
 )
+if TYPE_CHECKING:
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 if TYPE_CHECKING:
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
 MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DpPaddingMode(IntEnum):
 SUM_LEN = auto()
 def is_max_len(self):
-return self == DPPaddingMode.MAX_LEN
+return self == DpPaddingMode.MAX_LEN
 def is_sum_len(self):
-return self == DPPaddingMode.SUM_LEN
+return self == DpPaddingMode.SUM_LEN
 @classmethod
-def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
 # we choose the mode that minimizes the communication cost
 max_len = max(global_num_tokens)
 sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DpPaddingMode(IntEnum):
 return cls.SUM_LEN
 @classmethod
-def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
 return cls.MAX_LEN
+class _DpGatheredBufferWrapper:
+_hidden_size: int
+_dtype: torch.dtype
+_device: torch.device
+_global_dp_buffer_len: int
+_local_dp_buffer_len: int
+@classmethod
+def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+cls._hidden_size = hidden_size
+cls._dtype = dtype
+cls._device = device
+@classmethod
+def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+cls._global_dp_buffer_len = global_dp_buffer_len
+cls._local_dp_buffer_len = local_dp_buffer_len
+@classmethod
+def get_global_dp_buffer(cls) -> torch.Tensor:
+return torch.empty(
+(cls._global_dp_buffer_len, cls._hidden_size),
+dtype=cls._dtype,
+device=cls._device,
+)
+@classmethod
+def get_local_dp_buffer(cls) -> torch.Tensor:
+return torch.empty(
+(cls._local_dp_buffer_len, cls._hidden_size),
+dtype=cls._dtype,
+device=cls._device,
+)
+@classmethod
+def get_global_dp_buffer_len(cls) -> int:
+return cls._global_dp_buffer_len
+@classmethod
+def get_local_dp_buffer_len(cls) -> int:
+return cls._local_dp_buffer_len
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+_DpGatheredBufferWrapper.set_dp_buffer_len(
+global_dp_buffer_len, local_dp_buffer_len
+)
+def get_global_dp_buffer() -> torch.Tensor:
+return _DpGatheredBufferWrapper.get_global_dp_buffer()
+def get_local_dp_buffer() -> torch.Tensor:
+return _DpGatheredBufferWrapper.get_local_dp_buffer()
+def get_global_dp_buffer_len() -> int:
+return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+def get_local_dp_buffer_len() -> int:
+return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
 if not enable_dp_attention:
 return tp_rank, tp_size, 0
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
 def initialize_dp_attention(
-enable_dp_attention: bool,
-tp_rank: int,
-tp_size: int,
-dp_size: int,
-moe_dense_tp_size: int,
-pp_size: int,
+server_args: ServerArgs,
+model_config: ModelConfig,
 ):
 global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+enable_dp_attention = server_args.enable_dp_attention
+tp_size = server_args.tp_size
+dp_size = server_args.dp_size
+moe_dense_tp_size = server_args.moe_dense_tp_size
+pp_size = server_args.pp_size
+tp_rank = get_tensor_model_parallel_rank()
+_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
 _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
 enable_dp_attention, tp_rank, tp_size, dp_size
 )
@@ -135,38 +212,48 @@ def initialize_dp_attention(
 group_name="attention_tp",
 )
+_DpGatheredBufferWrapper.set_metadata(
+hidden_size=model_config.hidden_size,
+dtype=model_config.dtype,
+device=torch.device("cuda"),
+)
+def is_dp_attention_enabled() -> bool:
+return _ENABLE_DP_ATTENTION_FLAG
-def get_attention_tp_group():
+def get_attention_tp_group() -> GroupCoordinator:
 assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
 return _ATTN_TP_GROUP
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
 assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
 return _ATTN_TP_RANK
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
 assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
 return _ATTN_TP_SIZE
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
 assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
 return _ATTN_DP_RANK
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
 assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 return _ATTN_DP_SIZE
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
 assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
 return _LOCAL_ATTN_DP_RANK
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
 assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
 return _LOCAL_ATTN_DP_SIZE
...
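
The new `_DpGatheredBufferWrapper` keeps the gather-buffer metadata (hidden size, dtype, device) plus the per-batch global/local token counts as process-wide state, so any call site can allocate a gather buffer on demand instead of slicing a preallocated tensor. A rough usage sketch of the module-level helpers; it assumes an sglang install with a CUDA device, and the sizes and dtype below are illustrative, not taken from the diff:

```python
import torch

from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    get_global_dp_buffer,
    get_local_dp_buffer,
    set_dp_buffer_len,
)

# Once at startup; in this commit initialize_dp_attention() does this
# from model_config.hidden_size / model_config.dtype.
_DpGatheredBufferWrapper.set_metadata(
    hidden_size=4096,                  # illustrative value
    dtype=torch.bfloat16,              # illustrative value
    device=torch.device("cuda"),
)

# Once per forward batch; done by ForwardBatch preparation and by the
# CUDA-graph capture/replay closures in the later hunks.
set_dp_buffer_len(global_dp_buffer_len=8192, local_dp_buffer_len=2048)

# At any call site that previously sliced forward_batch.gathered_buffer:
global_buf = get_global_dp_buffer()    # torch.empty of shape (8192, 4096)
local_buf = get_local_dp_buffer()      # torch.empty of shape (2048, 4096)
```

Because the buffers are plain `torch.empty` allocations created when requested, the preallocated `gathered_buffer` tensors removed in the later hunks are no longer needed.
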
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
 tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-DPPaddingMode,
+DpPaddingMode,
 attn_tp_all_gather,
 attn_tp_all_gather_into_tensor,
 dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
 get_attention_dp_rank,
 get_attention_dp_size,
 get_attention_tp_size,
+get_global_dp_buffer,
 get_local_attention_dp_size,
+set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
 # The start position of local hidden states.
 dp_local_start_pos: Optional[torch.Tensor] = None
 dp_local_num_tokens: Optional[torch.Tensor] = None
-gathered_buffer: Optional[torch.Tensor] = None
-# Buffer to gather logits from all ranks.
-forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+global_dp_buffer_len: Optional[int] = None
 # Number of tokens to sample per DP rank
 global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
 global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
 # The gather mode for DP attention
-dp_padding_mode: Optional[DPPaddingMode] = None
+dp_padding_mode: Optional[DpPaddingMode] = None
 # for padding
 padded_static_len: int = -1
@@ -164,11 +164,10 @@ class LogitsMetadata:
 global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
 dp_local_start_pos=forward_batch.dp_local_start_pos,
 dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-gathered_buffer=forward_batch.gathered_buffer,
-forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+global_dp_buffer_len=forward_batch.global_dp_buffer_len,
 global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
 global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-dp_padding_mode=DPPaddingMode.SUM_LEN,
+dp_padding_mode=DpPaddingMode.SUM_LEN,
 )
 def compute_dp_attention_metadata(self):
@@ -188,16 +187,11 @@ class LogitsMetadata:
 if self.global_num_tokens_for_logprob_cpu is not None:
 # create a smaller buffer to reduce peak memory usage
-self.gathered_buffer = torch.empty(
-(
-sum(self.global_num_tokens_for_logprob_cpu),
-self.gathered_buffer.shape[1],
-),
-dtype=self.gathered_buffer.dtype,
-device=self.gathered_buffer.device,
-)
+self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
 else:
-self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+self.global_dp_buffer_len = self.global_dp_buffer_len
+set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
 class LogitsProcessor(nn.Module):
@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
 if self.do_tensor_parallel_all_gather_dp_attn:
 logits_metadata.compute_dp_attention_metadata()
 hidden_states, local_hidden_states = (
-logits_metadata.gathered_buffer,
+get_global_dp_buffer(),
 hidden_states,
 )
 dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
...
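
`LogitsMetadata` now carries an integer `global_dp_buffer_len` instead of a gather tensor: `compute_dp_attention_metadata()` records the (possibly smaller, logprob-only) length and publishes it via `set_dp_buffer_len`, and the tensor is materialized only at the gather site. Simplified flow assembled from the hunks above:

```python
# Inside LogitsProcessor, on the DP-attention all-gather path:
logits_metadata.compute_dp_attention_metadata()   # records length, calls set_dp_buffer_len(...)
hidden_states, local_hidden_states = (
    get_global_dp_buffer(),                       # buffer created here, on demand
    hidden_states,
)
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
```
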
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+get_attention_tp_group,
+is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
 self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
 self.tp_sync_group = get_tp_group().device_group
-if global_server_args_dict["enable_dp_attention"]:
+if is_dp_attention_enabled():
 self.tp_sync_group = get_attention_tp_group().device_group
 def forward(
...
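
`is_dp_attention_enabled()` replaces reads of `global_server_args_dict["enable_dp_attention"]` here and in the model files further down, which is also why the key is dropped from `GLOBAL_SERVER_ARGS_KEYS` in the next hunk. A minimal sketch of the sampler-side pattern, assuming the distributed groups have already been initialized:

```python
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    is_dp_attention_enabled,
)

# Token-id sync group: full TP group by default, attention-TP group when
# DP attention is on (mirrors the Sampler.__init__ hunk above).
tp_sync_group = get_tp_group().device_group
if is_dp_attention_enabled():
    tp_sync_group = get_attention_tp_group().device_group
```
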
@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
 "device",
 "disable_chunked_prefix_cache",
 "disable_radix_cache",
-"enable_dp_attention",
 "enable_two_batch_overlap",
 "tbo_token_distribution_threshold",
 "enable_dp_lm_head",
...
@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-DPPaddingMode,
+DpPaddingMode,
 get_attention_tp_rank,
 get_attention_tp_size,
+set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -349,30 +350,15 @@ class CudaGraphRunner:
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (self.dp_size,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token * self.dp_size,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 assert self.require_attn_tp_gather
 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (1,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 self.global_num_tokens_gpu = None
 self.global_num_tokens_for_logprob_gpu = None
-self.gathered_buffer = None
 self.custom_mask = torch.ones(
 (
@@ -556,7 +542,7 @@ class CudaGraphRunner:
 device=input_ids.device,
 )
 )
-gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+global_dp_buffer_len = num_tokens * self.dp_size
 elif self.require_attn_tp_gather:
 self.global_num_tokens_gpu.copy_(
 torch.tensor(
@@ -572,9 +558,9 @@ class CudaGraphRunner:
 device=input_ids.device,
 )
 )
-gathered_buffer = self.gathered_buffer[:num_tokens]
+global_dp_buffer_len = num_tokens
 else:
-gathered_buffer = None
+global_dp_buffer_len = None
 spec_info = self.get_spec_info(num_tokens)
 if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +593,8 @@ class CudaGraphRunner:
 positions=positions,
 global_num_tokens_gpu=self.global_num_tokens_gpu,
 global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-gathered_buffer=gathered_buffer,
+dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+global_dp_buffer_len=global_dp_buffer_len,
 mrope_positions=mrope_positions,
 spec_algorithm=self.model_runner.spec_algorithm,
 spec_info=spec_info,
@@ -637,6 +623,7 @@ class CudaGraphRunner:
 def run_once():
 # Clean intermediate result cache for DP attention
 forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 kwargs = {}
 if (
...
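
The CUDA graph runner no longer preallocates a `(max_num_token * dp_size, hidden_size)` gather buffer. Capture tracks only an integer and republishes it through `set_dp_buffer_len` inside `run_once()`, so every replay re-registers the lengths before the model body asks for its buffers. Condensed from the capture path above:

```python
# Only an integer survives from the buffer bookkeeping:
global_dp_buffer_len = num_tokens * self.dp_size   # DP gather case
# (num_tokens for the attn-TP gather case, None when neither applies)

def run_once():
    # Clean intermediate result cache for DP attention
    forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
    set_dp_buffer_len(global_dp_buffer_len, num_tokens)
    ...
```
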
@@ -40,9 +40,10 @@ import triton.language as tl
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-DPPaddingMode,
+DpPaddingMode,
 get_attention_dp_rank,
 get_attention_tp_size,
+set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -274,13 +275,13 @@ class ForwardBatch:
 global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
 global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
 # The padding mode for DP attention
-dp_padding_mode: Optional[DPPaddingMode] = None
+dp_padding_mode: Optional[DpPaddingMode] = None
 # for extend, local start pos and num tokens is different in logits processor
 # this will be computed in get_dp_local_info
 # this will be recomputed in LogitsMetadata.from_forward_batch
 dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
 dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-gathered_buffer: Optional[torch.Tensor] = None
+global_dp_buffer_len: Optional[int] = None
 is_extend_in_batch: bool = False
 can_run_dp_cuda_graph: bool = False
 global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@ class ForwardBatch:
 (global_num_tokens[i] - 1) // attn_tp_size + 1
 ) * attn_tp_size
-dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
 self.dp_padding_mode = dp_padding_mode
 if dp_padding_mode.is_max_len():
@@ -642,17 +643,14 @@ class ForwardBatch:
 else:
 buffer_len = sum(global_num_tokens)
-self.gathered_buffer = torch.zeros(
-(buffer_len, model_runner.model_config.hidden_size),
-dtype=model_runner.dtype,
-device=model_runner.device,
-)
 if len(global_num_tokens) > 1:
 num_tokens = global_num_tokens[get_attention_dp_rank()]
 else:
 num_tokens = global_num_tokens[0]
+self.global_dp_buffer_len = buffer_len
+set_dp_buffer_len(buffer_len, num_tokens)
 bs = self.batch_size
 if self.forward_mode.is_decode():
...
@@ -603,12 +603,8 @@ class ModelRunner:
 duplicate_tp_group=self.server_args.enable_pdmux,
 )
 initialize_dp_attention(
-enable_dp_attention=self.server_args.enable_dp_attention,
-tp_rank=self.tp_rank,
-tp_size=self.tp_size,
-dp_size=self.server_args.dp_size,
-moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-pp_size=self.server_args.pp_size,
+server_args=self.server_args,
+model_config=self.model_config,
 )
 min_per_gpu_memory = get_available_gpu_memory(
...
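
`initialize_dp_attention` now takes the whole `ServerArgs` and `ModelConfig` instead of six scalars. Besides shrinking the call site, `model_config` is what feeds `_DpGatheredBufferWrapper.set_metadata` (hidden size and dtype), and the TP rank is read from the already-initialized distributed state. The call after this change:

```python
initialize_dp_attention(
    server_args=self.server_args,
    model_config=self.model_config,
)
```
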
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
...
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
 get_attention_tp_rank,
 get_attention_tp_size,
 get_local_attention_dp_size,
+is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
 rope_theta = getattr(config, "rope_theta", 10000)
 rope_scaling = getattr(config, "rope_scaling", None)
 max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
 self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
 self.layer_id = layer_id
 self.is_nextn = is_nextn
@@ -1917,7 +1917,9 @@ class DeepseekV2DecoderLayer(nn.Module):
 should_allreduce_fusion = (
 self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+and not (
+is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+)
 and not self.is_nextn
 )
@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 )
 self.alt_stream = torch.cuda.Stream() if _is_cuda else None
 self.layers = nn.ModuleList(
...
@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
 get_attention_tp_rank,
 get_attention_tp_size,
 get_local_attention_dp_size,
+is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
 )
 rms_norm_eps = config.rms_norm_eps
 attention_bias = config.attention_bias
-self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
 self.layer_id = layer_id
 self.self_attn = Glm4MoeAttention(
 hidden_size=self.hidden_size,
@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 )
 self.alt_stream = torch.cuda.Stream() if _is_cuda else None
 self.layers = nn.ModuleList(
...
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
...
@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
 get_attention_tp_rank,
 get_attention_tp_size,
 get_local_attention_dp_size,
+is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -565,7 +566,7 @@ class GptOssModel(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
 else:
...
@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
 get_attention_tp_rank,
 get_attention_tp_size,
 get_local_attention_dp_size,
+is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
 ForwardBatch,
 ForwardMode,
@@ -466,7 +466,7 @@ class Llama4Model(nn.Module):
 config.hidden_size,
 quant_config=quant_config,
 prefix=add_prefix("embed_tokens", prefix),
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 )
 self.layers = make_layers(
 config.num_hidden_layers,
...
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
 get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
 default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
 config.vocab_size,
 config.hidden_size,
 quant_config=quant_config,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
 else:
...
@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
 get_attention_tp_rank,
 get_attention_tp_size,
 get_local_attention_dp_size,
+is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -420,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
 else:
...
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+get_attention_tp_rank,
+get_attention_tp_size,
+is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
 self.embed_tokens = VocabParallelEmbedding(
 config.vocab_size,
 config.hidden_size,
-enable_tp=not global_server_args_dict["enable_dp_attention"],
+enable_tp=not is_dp_attention_enabled(),
 prefix=add_prefix("embed_tokens", prefix),
 )
...
+from __future__ import annotations
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 import torch
+from sglang.srt.layers.dp_attention import set_dp_buffer_len
+if TYPE_CHECKING:
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 if _ENABLE_PROFILE:
@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
 class _StageExecutor:
-def __init__(self, debug_name: str, stages: List[Stage], inputs):
+def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
 self._debug_name = debug_name
 self._stages = stages
 self._index = 0
 self._stage_state = _StateDict()
 self._stage_output = inputs
+# handling DP attention
+forward_batch: ForwardBatch = inputs["forward_batch"]
+self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
 def next(self):
 assert not self.done
 stage = self._stages[self._index]
+if self._global_dp_buffer_len is not None:
+set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
 with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
 for op in stage:
 with _annotate_region(debug_name=op.debug_name):
...
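
`_StageExecutor` snapshots the batch's buffer lengths at construction and republishes them before every stage, presumably because the lengths now live in process-wide state that the other overlapped micro-batch can overwrite between stages. Condensed from the hunk above:

```python
# In __init__: cache the per-batch lengths once.
forward_batch = inputs["forward_batch"]
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]

# In next(): republish them before running the stage's operations.
if self._global_dp_buffer_len is not None:
    set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
```
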
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
 CUDA_GRAPH_CAPTURE_FAILED_MSG,
 CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (self.dp_size,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token * self.dp_size,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 assert self.require_attn_tp_gather
 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (1,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 self.global_num_tokens_gpu = None
 self.global_num_tokens_for_logprob_gpu = None
-self.gathered_buffer = None
 # Capture
 try:
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
 )
 )
 global_num_tokens = self.global_num_tokens_gpu
-gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+global_dp_buffer_len = num_tokens * self.dp_size
 global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
 elif self.require_attn_tp_gather:
 self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
 )
 )
 global_num_tokens = self.global_num_tokens_gpu
-gathered_buffer = self.gathered_buffer[:num_tokens]
+global_dp_buffer_len = num_tokens
 global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
 else:
 global_num_tokens = None
-gathered_buffer = None
+global_dp_buffer_len = None
 global_num_tokens_for_logprob = None
 spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
 return_logprob=False,
 positions=positions,
 global_num_tokens_gpu=global_num_tokens,
-dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-gathered_buffer=gathered_buffer,
+dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+global_dp_buffer_len=global_dp_buffer_len,
 spec_algorithm=self.model_runner.spec_algorithm,
 spec_info=spec_info,
 capture_hidden_mode=(
@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
 def run_once():
 # Clean intermediate result cache for DP attention
 forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 # Backup two fields, which will be modified in-place in `draft_forward`.
 output_cache_loc_backup = forward_batch.out_cache_loc
...
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
 CUDA_GRAPH_CAPTURE_FAILED_MSG,
 CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (self.dp_size,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token * self.dp_size,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 assert self.require_attn_tp_gather
 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 self.global_num_tokens_for_logprob_gpu = torch.zeros(
 (1,), dtype=torch.int32
 )
-self.gathered_buffer = torch.zeros(
-(
-self.max_num_token,
-self.model_runner.model_config.hidden_size,
-),
-dtype=self.model_runner.dtype,
-)
 else:
 self.global_num_tokens_gpu = None
 self.global_num_tokens_for_logprob_gpu = None
-self.gathered_buffer = None
 if hasattr(
 self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
 device=self.input_ids.device,
 )
 )
-gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+global_dp_buffer_len = num_tokens * self.dp_size
 elif self.require_attn_tp_gather:
 self.global_num_tokens_gpu.copy_(
 torch.tensor(
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
 device=self.input_ids.device,
 )
 )
-gathered_buffer = self.gathered_buffer[:num_tokens]
+global_dp_buffer_len = num_tokens
 else:
-gathered_buffer = None
+global_dp_buffer_len = None
 spec_info = EagleDraftInput(
 hidden_states=hidden_states,
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
 positions=positions,
 global_num_tokens_gpu=self.global_num_tokens_gpu,
 global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-gathered_buffer=gathered_buffer,
+dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+global_dp_buffer_len=global_dp_buffer_len,
 spec_algorithm=self.model_runner.spec_algorithm,
 spec_info=spec_info,
 capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
 def run_once():
 # Clean intermediate result cache for DP attention
 forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 # Backup two fields, which will be modified in-place in `draft_forward`.
 output_cache_loc_backup = forward_batch.out_cache_loc
...