Unverified commit b87aacb5 authored by Cheng Wan, committed by GitHub

[DP Attention] Refactor: adding some utility functions (#9136)
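
This refactor removes the per-batch `gathered_buffer` tensors and centralizes DP-attention buffer management in `sglang.srt.layers.dp_attention`:

- `ForwardBatch`, `LogitsMetadata`, and the CUDA graph runners now carry an integer `global_dp_buffer_len` instead of a preallocated `gathered_buffer`; the actual tensors are created on demand by the new `_DpGatheredBufferWrapper` via `get_global_dp_buffer()` / `get_local_dp_buffer()`, with sizes published through `set_dp_buffer_len()`.
- `DPPaddingMode` is renamed to `DpPaddingMode`.
- A new `is_dp_attention_enabled()` helper replaces reads of `global_server_args_dict["enable_dp_attention"]`, which is dropped from `GLOBAL_SERVER_ARGS_KEYS`.
- `initialize_dp_attention()` now takes `server_args` and `model_config` directly and records the buffer metadata (hidden size, dtype, device).

A hedged sketch of the resulting call-site pattern (illustrative, not an exact excerpt from the diff):

```python
# before: slice the batch-owned gather buffer
# hidden_states = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]

# after: ask dp_attention for a freshly allocated local buffer
from sglang.srt.layers.dp_attention import get_local_dp_buffer

hidden_states = get_local_dp_buffer()
```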

parent b3363cc1
......@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
......@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
context: CommunicateContext,
) -> torch.Tensor:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
attn_tp_all_gather_into_tensor(
......@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
):
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
residual, local_residual = (
torch.empty_like(
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
),
get_local_dp_buffer(),
residual,
)
attn_tp_all_gather_into_tensor(residual, local_residual)
......@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
residual = hidden_states
hidden_states = layernorm(hidden_states)
hidden_states, local_hidden_states = (
torch.empty_like(forward_batch.gathered_buffer),
get_global_dp_buffer(),
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
......@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
allow_reduce_scatter: bool = False,
):
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
......@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
get_local_dp_buffer(),
hidden_states,
)
attn_tp_all_gather_into_tensor(
......
......@@ -4,7 +4,7 @@ import functools
import logging
from contextlib import contextmanager
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import triton
......@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_ATTN_DP_RANK = None
_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None
_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_DP_RANK: Optional[int] = None
_ATTN_DP_SIZE: Optional[int] = None
_LOCAL_ATTN_DP_SIZE: Optional[int] = None
_LOCAL_ATTN_DP_RANK: Optional[int] = None
_ENABLE_DP_ATTENTION_FLAG: bool = False
class DPPaddingMode(IntEnum):
class DpPaddingMode(IntEnum):
# Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
MAX_LEN = auto()
......@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
SUM_LEN = auto()
def is_max_len(self):
return self == DPPaddingMode.MAX_LEN
return self == DpPaddingMode.MAX_LEN
def is_sum_len(self):
return self == DPPaddingMode.SUM_LEN
return self == DpPaddingMode.SUM_LEN
@classmethod
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
# we choose the mode that minimizes the communication cost
max_len = max(global_num_tokens)
sum_len = sum(global_num_tokens)
......@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
return cls.SUM_LEN
@classmethod
def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
return cls.MAX_LEN
class _DpGatheredBufferWrapper:
_hidden_size: int
_dtype: torch.dtype
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
@classmethod
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
cls._hidden_size = hidden_size
cls._dtype = dtype
cls._device = device
@classmethod
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
@classmethod
def get_global_dp_buffer_len(cls) -> int:
return cls._global_dp_buffer_len
@classmethod
def get_local_dp_buffer_len(cls) -> int:
return cls._local_dp_buffer_len
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
_DpGatheredBufferWrapper.set_dp_buffer_len(
global_dp_buffer_len, local_dp_buffer_len
)
def get_global_dp_buffer() -> torch.Tensor:
return _DpGatheredBufferWrapper.get_global_dp_buffer()
def get_local_dp_buffer() -> torch.Tensor:
return _DpGatheredBufferWrapper.get_local_dp_buffer()
def get_global_dp_buffer_len() -> int:
return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
def get_local_dp_buffer_len() -> int:
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
......@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
def initialize_dp_attention(
enable_dp_attention: bool,
tp_rank: int,
tp_size: int,
dp_size: int,
moe_dense_tp_size: int,
pp_size: int,
server_args: ServerArgs,
model_config: ModelConfig,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
enable_dp_attention = server_args.enable_dp_attention
tp_size = server_args.tp_size
dp_size = server_args.dp_size
moe_dense_tp_size = server_args.moe_dense_tp_size
pp_size = server_args.pp_size
tp_rank = get_tensor_model_parallel_rank()
_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
......@@ -135,38 +212,48 @@ def initialize_dp_attention(
group_name="attention_tp",
)
_DpGatheredBufferWrapper.set_metadata(
hidden_size=model_config.hidden_size,
dtype=model_config.dtype,
device=torch.device("cuda"),
)
def is_dp_attention_enabled() -> bool:
return _ENABLE_DP_ATTENTION_FLAG
def get_attention_tp_group():
def get_attention_tp_group() -> GroupCoordinator:
assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
return _ATTN_TP_GROUP
def get_attention_tp_rank():
def get_attention_tp_rank() -> int:
assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
return _ATTN_TP_RANK
def get_attention_tp_size():
def get_attention_tp_size() -> int:
assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
return _ATTN_TP_SIZE
def get_attention_dp_rank():
def get_attention_dp_rank() -> int:
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
return _ATTN_DP_RANK
def get_attention_dp_size():
def get_attention_dp_size() -> int:
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE
def get_local_attention_dp_rank():
def get_local_attention_dp_rank() -> int:
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_RANK
def get_local_attention_dp_size():
def get_local_attention_dp_size() -> int:
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_SIZE
......
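The new helpers above can be exercised in isolation. A minimal sketch, assuming sglang and torch are installed; it calls `_DpGatheredBufferWrapper.set_metadata` by hand only to stand in for `initialize_dp_attention()`, which normally records this metadata from `model_config`:

```python
import torch

from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    get_global_dp_buffer,
    get_local_dp_buffer,
    set_dp_buffer_len,
)

# Normally done inside initialize_dp_attention(); set by hand here so the
# example runs without a distributed setup (CPU device, illustrative sizes).
_DpGatheredBufferWrapper.set_metadata(
    hidden_size=4096, dtype=torch.bfloat16, device=torch.device("cpu")
)

# Per-batch lengths, normally published by ForwardBatch or the CUDA graph runners.
set_dp_buffer_len(global_dp_buffer_len=64, local_dp_buffer_len=16)

print(get_global_dp_buffer().shape)  # torch.Size([64, 4096])
print(get_local_dp_buffer().shape)   # torch.Size([16, 4096])
```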
......@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
DpPaddingMode,
attn_tp_all_gather,
attn_tp_all_gather_into_tensor,
dp_gather_replicate,
......@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
get_global_dp_buffer,
get_local_attention_dp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
......@@ -108,14 +110,12 @@ class LogitsMetadata:
# The start position of local hidden states.
dp_local_start_pos: Optional[torch.Tensor] = None
dp_local_num_tokens: Optional[torch.Tensor] = None
gathered_buffer: Optional[torch.Tensor] = None
# Buffer to gather logits from all ranks.
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
global_dp_buffer_len: Optional[int] = None
# Number of tokens to sample per DP rank
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# The gather mode for DP attention
dp_padding_mode: Optional[DPPaddingMode] = None
dp_padding_mode: Optional[DpPaddingMode] = None
# for padding
padded_static_len: int = -1
......@@ -164,11 +164,10 @@ class LogitsMetadata:
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
gathered_buffer=forward_batch.gathered_buffer,
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
global_dp_buffer_len=forward_batch.global_dp_buffer_len,
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.SUM_LEN,
dp_padding_mode=DpPaddingMode.SUM_LEN,
)
def compute_dp_attention_metadata(self):
......@@ -188,16 +187,11 @@ class LogitsMetadata:
if self.global_num_tokens_for_logprob_cpu is not None:
# create a smaller buffer to reduce peak memory usage
self.gathered_buffer = torch.empty(
(
sum(self.global_num_tokens_for_logprob_cpu),
self.gathered_buffer.shape[1],
),
dtype=self.gathered_buffer.dtype,
device=self.gathered_buffer.device,
)
self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
else:
self.gathered_buffer = torch.empty_like(self.gathered_buffer)
self.global_dp_buffer_len = self.global_dp_buffer_len
set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
class LogitsProcessor(nn.Module):
......@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata()
hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer,
get_global_dp_buffer(),
hidden_states,
)
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
......
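For the logprob gather path, `LogitsMetadata` now only records a (possibly smaller) length instead of allocating a smaller tensor itself. A minimal sketch with illustrative numbers, not an excerpt from sglang:

```python
# Mirrors the new compute_dp_attention_metadata flow (illustrative values only).
global_num_tokens_for_logprob_cpu = [2, 1, 3, 2]  # tokens needing logprobs per DP rank
dp_local_num_tokens = 3                           # this rank's share

# "Create a smaller buffer to reduce peak memory usage" now means a smaller length:
global_dp_buffer_len = sum(global_num_tokens_for_logprob_cpu)  # 8

# Published so that get_global_dp_buffer() later allocates (8, hidden_size):
#     set_dp_buffer_len(global_dp_buffer_len, dp_local_num_tokens)
print(global_dp_buffer_len, dp_local_num_tokens)  # 8 3
```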
......@@ -6,7 +6,10 @@ import torch.distributed as dist
from torch import nn
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
is_dp_attention_enabled,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
......@@ -32,7 +35,7 @@ class Sampler(nn.Module):
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tp_group().device_group
if global_server_args_dict["enable_dp_attention"]:
if is_dp_attention_enabled():
self.tp_sync_group = get_attention_tp_group().device_group
def forward(
......
......@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"device",
"disable_chunked_prefix_cache",
"disable_radix_cache",
"enable_dp_attention",
"enable_two_batch_overlap",
"tbo_token_distribution_threshold",
"enable_dp_lm_head",
......
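The `is_dp_attention_enabled()` accessor follows the same module-level-flag pattern as the other getters, which is why `"enable_dp_attention"` can be dropped from `GLOBAL_SERVER_ARGS_KEYS` above. A standalone, simplified sketch of that pattern (not the sglang module itself):

```python
# Simplified sketch of the flag-and-getter pattern; names mirror dp_attention.py.
_ENABLE_DP_ATTENTION_FLAG: bool = False


def initialize(enable_dp_attention: bool) -> None:
    """Stands in for initialize_dp_attention(), which sets the real flag."""
    global _ENABLE_DP_ATTENTION_FLAG
    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention


def is_dp_attention_enabled() -> bool:
    return _ENABLE_DP_ATTENTION_FLAG


if __name__ == "__main__":
    initialize(True)
    # The Sampler picks its sync group from the flag instead of reading
    # global_server_args_dict["enable_dp_attention"].
    group = "attention_tp" if is_dp_attention_enabled() else "tp"
    print(group)  # attention_tp
```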
......@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
)
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
DpPaddingMode,
get_attention_tp_rank,
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
......@@ -349,30 +350,15 @@ class CudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
self.custom_mask = torch.ones(
(
......@@ -556,7 +542,7 @@ class CudaGraphRunner:
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
global_dp_buffer_len = num_tokens * self.dp_size
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
......@@ -572,9 +558,9 @@ class CudaGraphRunner:
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[:num_tokens]
global_dp_buffer_len = num_tokens
else:
gathered_buffer = None
global_dp_buffer_len = None
spec_info = self.get_spec_info(num_tokens)
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
......@@ -607,8 +593,8 @@ class CudaGraphRunner:
positions=positions,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
global_dp_buffer_len=global_dp_buffer_len,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
......@@ -637,6 +623,7 @@ class CudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
kwargs = {}
if (
......
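During CUDA graph capture the preallocated `gathered_buffer` is replaced by a pair of lengths that `run_once()` re-publishes before each replayed body. A small sketch of that bookkeeping with illustrative values:

```python
# Illustrative values; mirrors the length computation in the hunk above.
dp_size = 4
num_tokens = 16

require_gathered_buffer = True   # DP-gather path
require_attn_tp_gather = False

if require_gathered_buffer:
    # CUDA graphs use the default MAX_LEN mode: every DP rank contributes num_tokens.
    global_dp_buffer_len = num_tokens * dp_size
elif require_attn_tp_gather:
    global_dp_buffer_len = num_tokens
else:
    global_dp_buffer_len = None

# Inside run_once(), before the captured body executes:
#     set_dp_buffer_len(global_dp_buffer_len, num_tokens)
print(global_dp_buffer_len)  # 64
```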
......@@ -40,9 +40,10 @@ import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
DpPaddingMode,
get_attention_dp_rank,
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
......@@ -274,13 +275,13 @@ class ForwardBatch:
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# The padding mode for DP attention
dp_padding_mode: Optional[DPPaddingMode] = None
dp_padding_mode: Optional[DpPaddingMode] = None
# for extend, local start pos and num tokens is different in logits processor
# this will be computed in get_dp_local_info
# this will be recomputed in LogitsMetadata.from_forward_batch
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
global_dp_buffer_len: Optional[int] = None
is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None
......@@ -628,7 +629,7 @@ class ForwardBatch:
(global_num_tokens[i] - 1) // attn_tp_size + 1
) * attn_tp_size
dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
......@@ -642,17 +643,14 @@ class ForwardBatch:
else:
buffer_len = sum(global_num_tokens)
self.gathered_buffer = torch.zeros(
(buffer_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=model_runner.device,
)
if len(global_num_tokens) > 1:
num_tokens = global_num_tokens[get_attention_dp_rank()]
else:
num_tokens = global_num_tokens[0]
self.global_dp_buffer_len = buffer_len
set_dp_buffer_len(buffer_len, num_tokens)
bs = self.batch_size
if self.forward_mode.is_decode():
......
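Outside CUDA graphs, `ForwardBatch` publishes the lengths itself instead of allocating `gathered_buffer`. A sketch of the SUM_LEN branch shown in the hunk above, with illustrative token counts:

```python
# Illustrative values; mirrors the SUM_LEN branch above.
global_num_tokens = [5, 9, 7, 3]   # tokens per attention-DP rank (after attn_tp padding)
attention_dp_rank = 2

buffer_len = sum(global_num_tokens)                 # SUM_LEN: gather exactly sum(tokens)
num_tokens = global_num_tokens[attention_dp_rank]   # this rank's local share

# Recorded on the batch and published for the buffer getters:
#     self.global_dp_buffer_len = buffer_len
#     set_dp_buffer_len(buffer_len, num_tokens)
print(buffer_len, num_tokens)  # 24 7
```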
......@@ -603,12 +603,8 @@ class ModelRunner:
duplicate_tp_group=self.server_args.enable_pdmux,
)
initialize_dp_attention(
enable_dp_attention=self.server_args.enable_dp_attention,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
pp_size=self.server_args.pp_size,
server_args=self.server_args,
model_config=self.model_config,
)
min_per_gpu_memory = get_available_gpu_memory(
......
......@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
......
......@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.layer_id = layer_id
self.is_nextn = is_nextn
......@@ -1917,7 +1917,9 @@ class DeepseekV2DecoderLayer(nn.Module):
should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
and not (
is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
)
and not self.is_nextn
)
......@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
self.layers = nn.ModuleList(
......
......@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.self_attn = Glm4MoeAttention(
hidden_size=self.hidden_size,
......@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
self.layers = nn.ModuleList(
......
......@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
......
......@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -565,7 +566,7 @@ class GptOssModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
else:
......
......@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
......@@ -466,7 +466,7 @@ class Llama4Model(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix),
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
)
self.layers = make_layers(
config.num_hidden_layers,
......
......@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
......@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
......@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
else:
......
......@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -420,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
else:
......
......@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
......@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
......
from __future__ import annotations
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generator, List, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
import torch
from sglang.srt.layers.dp_attention import set_dp_buffer_len
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
if _ENABLE_PROFILE:
......@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
class _StageExecutor:
def __init__(self, debug_name: str, stages: List[Stage], inputs):
def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
self._debug_name = debug_name
self._stages = stages
self._index = 0
self._stage_state = _StateDict()
self._stage_output = inputs
# handling DP attention
forward_batch: ForwardBatch = inputs["forward_batch"]
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
def next(self):
assert not self.done
stage = self._stages[self._index]
if self._global_dp_buffer_len is not None:
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
for op in stage:
with _annotate_region(debug_name=op.debug_name):
......
......@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
......@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
# Capture
try:
......@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
global_dp_buffer_len = num_tokens * self.dp_size
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
......@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_dp_buffer_len = num_tokens
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_dp_buffer_len = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
......@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
global_dp_buffer_len=global_dp_buffer_len,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
......@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
......
......@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
......@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
if hasattr(
self.model_runner.model_config.hf_config, "draft_vocab_size"
......@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
device=self.input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
global_dp_buffer_len = num_tokens * self.dp_size
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
......@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
device=self.input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[:num_tokens]
global_dp_buffer_len = num_tokens
else:
gathered_buffer = None
global_dp_buffer_len = None
spec_info = EagleDraftInput(
hidden_states=hidden_states,
......@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
positions=positions,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
global_dp_buffer_len=global_dp_buffer_len,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=CaptureHiddenMode.LAST,
......@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
......