Unverified Commit 8ecef73f authored by Baizhou Zhang, committed by GitHub

[1/2] Support deterministic inference with flashinfer attention backend (#10645)


Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
parent 1d1ce624
@@ -197,6 +197,11 @@ class Envs:
SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False)
SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False)
# Deterministic inference
SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
# fmt: on
...
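For context, the two tile-size knobs above are read through the get_int_env_var helper that the hunks below import from sglang.srt.utils. A minimal, self-contained sketch of how they are consumed; the helper here is a stand-in that mimics that behavior, not the actual sglang implementation:

import os

def get_int_env_var(name: str, default: int) -> int:
    # Stand-in for sglang.srt.utils.get_int_env_var (assumed behavior):
    # fall back to the default when the variable is unset or empty.
    value = os.environ.get(name, "")
    return int(value) if value else default

os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
prefill_tile = get_int_env_var("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096)
decode_tile = get_int_env_var("SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048)
print(prefill_tile, decode_tile)  # 4096 2048 unless overridden by the user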
@@ -31,6 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
is_sm100_supported,
next_power_of_2,
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
if is_flashinfer_available():
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
@@ -123,12 +125,33 @@ class FlashInferAttnBackend(AttentionBackend):
):
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
# When deterministic inference is enabled, tensor cores should be used for decode
# Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
# More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
self.enable_deterministic = (
model_runner.server_args.enable_deterministic_inference
)
self.prefill_split_tile_size = None
self.decode_split_tile_size = None
self.disable_cuda_graph_kv_split = False
if self.enable_deterministic:
self.decode_use_tensor_cores = True
self.prefill_split_tile_size = get_int_env_var(
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
)
self.decode_split_tile_size = get_int_env_var(
"SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
)
self.disable_cuda_graph_kv_split = True
global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:
# different from flashinfer zero_init_global_workspace_buffer
global_workspace_size = global_config.flashinfer_workspace_size
global_workspace_buffer = torch.empty(
global_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
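Why the fixed split tile sizes matter for determinism: with a fixed tile size, the number of KV chunks each request's attention is reduced over depends only on that request's own KV length, not on how the kernel balances work across the batch, so the floating-point reduction order stays the same across batch sizes (see the flashinfer PR referenced in the comment above). A small illustration of that invariant, with hypothetical lengths:

def num_kv_splits(kv_len: int, split_tile_size: int) -> int:
    # Number of fixed-size KV chunks one request is reduced over;
    # independent of how many other requests share the batch.
    return (kv_len + split_tile_size - 1) // split_tile_size

assert num_kv_splits(4096, 2048) == 2   # same answer at batch size 1 or 64
assert num_kv_splits(5000, 2048) == 3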
@@ -219,6 +242,8 @@ class FlashInferAttnBackend(AttentionBackend):
decode_wrappers=self.decode_wrappers,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
fixed_split_size=self.decode_split_tile_size,
disable_split_kv=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
elif forward_batch.forward_mode.is_draft_extend():
@@ -258,7 +283,7 @@ class FlashInferAttnBackend(AttentionBackend):
use_ragged = False
extend_no_prefix = False
else:
use_ragged = not self.enable_deterministic
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
@@ -271,6 +296,7 @@ class FlashInferAttnBackend(AttentionBackend):
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
fixed_split_size=self.prefill_split_tile_size,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -347,6 +373,8 @@ class FlashInferAttnBackend(AttentionBackend):
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
spec_info=spec_info,
fixed_split_size=None,
disable_split_kv=self.disable_cuda_graph_kv_split,
)
self.decode_cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -439,6 +467,8 @@ class FlashInferAttnBackend(AttentionBackend):
decode_wrappers=self.decode_cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
fixed_split_size=None,
disable_split_kv=self.disable_cuda_graph_kv_split,
)
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
@@ -646,6 +676,8 @@ class FlashInferIndicesUpdaterDecode:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
@@ -661,6 +693,8 @@ class FlashInferIndicesUpdaterDecode:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
@@ -672,6 +706,8 @@ class FlashInferIndicesUpdaterDecode:
None,
spec_info,
seq_lens_cpu,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
def update_sliding_window(
@@ -685,6 +721,8 @@ class FlashInferIndicesUpdaterDecode:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
assert self.sliding_window_size is not None
for wrapper_id in range(2):
@@ -735,6 +773,8 @@ class FlashInferIndicesUpdaterDecode:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -771,6 +811,8 @@ class FlashInferIndicesUpdaterDecode:
],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
if spec_info is None:
bs = len(req_pool_indices)
@@ -825,6 +867,10 @@ class FlashInferIndicesUpdaterDecode:
data_type=self.data_type,
q_data_type=self.q_data_type,
non_blocking=True,
fixed_split_size=fixed_split_size,
disable_split_kv=(
disable_split_kv if disable_split_kv is not None else False
),
)
if locally_override:
@@ -876,6 +922,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
@@ -893,6 +940,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -916,6 +964,7 @@ class FlashInferIndicesUpdaterPrefill:
self.qo_indptr[0],
use_ragged,
spec_info,
fixed_split_size=fixed_split_size,
)
def update_sliding_window(
@@ -931,6 +980,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -979,6 +1029,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -1024,6 +1075,7 @@ class FlashInferIndicesUpdaterPrefill:
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
bs = len(seq_lens)
if spec_info is None:
@@ -1094,6 +1146,7 @@ class FlashInferIndicesUpdaterPrefill:
kv_data_type=self.data_type,
custom_mask=custom_mask,
non_blocking=True,
fixed_split_size=fixed_split_size,
)
@@ -1327,6 +1380,8 @@ def fast_decode_plan(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
non_blocking: bool = True,
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
) -> None:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
@@ -1352,6 +1407,9 @@ def fast_decode_plan(
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
# Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
if fixed_split_size is None:
fixed_split_size = -1
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
@@ -1433,8 +1491,8 @@ def fast_decode_plan(
head_dim,
False, # causal
window_left,
fixed_split_size,
disable_split_kv,
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}")
...
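As the comment in fast_decode_plan above notes, the planner expects an integer here, so a missing fixed_split_size is encoded as -1 ("no fixed split") before being forwarded along with disable_split_kv. A trivial sketch of that sentinel handling in isolation:

from typing import Optional

def normalize_fixed_split_size(fixed_split_size: Optional[int]) -> int:
    # -1 is the "no fixed split" sentinel passed down to the plan call.
    return -1 if fixed_split_size is None else fixed_split_size

assert normalize_fixed_split_size(None) == -1
assert normalize_fixed_split_size(2048) == 2048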
@@ -14,6 +14,7 @@
"""Fused operators for normalization layers."""
import logging
import os
from typing import Optional, Tuple, Union
import torch
@@ -80,6 +81,8 @@ class RMSNorm(CustomOp):
)
if _use_aiter:
self._forward_method = self.forward_aiter
if os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] == "1":
self._forward_method = self.forward_native
def forward_cuda(
self,
...
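The RMSNorm change keys off the environment variable that ServerArgs exports later in this commit, presumably so normalization runs through plain PyTorch ops that batch_invariant_ops can patch rather than the fused kernel. A standalone, hypothetical sketch of the same gating pattern (using .get so the check also tolerates an unset variable, which is an assumption on my part; the server path always exports it):

import os

class NormDispatchSketch:
    # Hypothetical stand-in for the dispatch above: prefer the fused kernel,
    # but switch to the native path when deterministic inference is requested.
    def __init__(self):
        self._forward_method = self.forward_fused
        if os.environ.get("SGLANG_ENABLE_DETERMINISTIC_INFERENCE", "0") == "1":
            self._forward_method = self.forward_native

    def forward_fused(self, x):
        return x  # placeholder for the fused CUDA/aiter kernel

    def forward_native(self, x):
        return x  # placeholder for the pure-PyTorch implementation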
@@ -111,6 +111,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_symm_mem",
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
]
# Put some global args for easy access
...
@@ -541,7 +541,9 @@ class PrefillAdder:
return self.budget_state()
def add_one_req(
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
):
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -600,6 +602,17 @@ class PrefillAdder:
if trunc_len <= 0:
return AddReqResult.OTHER
# When the truncation align size is set, we want to ensure that the prefill prefix length is a multiple of the truncation align size.
# A typical use case is when deterministic inference is enabled with the flashinfer attention backend,
# where the prefill prefix length must be a multiple of the attention split size.
if truncation_align_size is not None:
if trunc_len < truncation_align_size:
return AddReqResult.OTHER
else:
trunc_len = truncation_align_size * (
trunc_len // truncation_align_size
)
# Chunked prefill
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
...
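A worked example of the truncation logic added to add_one_req above, with an illustrative chunked-prefill budget (the 4096 matches the default prefill split tile size):

truncation_align_size = 4096
trunc_len = 10_000  # hypothetical remaining token budget for this request
if trunc_len < truncation_align_size:
    print("reject for now (AddReqResult.OTHER)")
else:
    trunc_len = truncation_align_size * (trunc_len // truncation_align_size)
    print(trunc_len)  # 8192: the chunk boundary stays a multiple of the split size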
@@ -172,6 +172,7 @@ from sglang.srt.utils import (
freeze_gc,
get_available_gpu_memory,
get_bool_env_var,
get_int_env_var,
get_zmq_socket,
is_cpu,
kill_itself_when_parent_died,
@@ -565,6 +566,17 @@ class Scheduler(
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
if (
self.server_args.enable_deterministic_inference
and self.server_args.attention_backend == "flashinfer"
):
self.truncation_align_size = get_int_env_var(
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
)
else:
self.truncation_align_size = None
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -1846,7 +1858,11 @@ class Scheduler(
continue
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(
req,
has_chunked_req=(self.chunked_req is not None),
truncation_align_size=self.truncation_align_size,
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
...
@@ -406,6 +406,12 @@ class ModelRunner:
)
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
# Enable batch invariant mode
if server_args.enable_deterministic_inference:
from batch_invariant_ops import enable_batch_invariant_mode
enable_batch_invariant_mode()
# Init memory pool and attention backends
self.init_memory_pool(
min_per_gpu_memory,
...
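enable_batch_invariant_mode comes from the batch_invariant_ops package that ServerArgs checks for below; conceptually, "batch invariant" means a row of the input produces bitwise-identical results whether it is computed alone or inside a larger batch. A small, illustrative checker for that property (not part of this commit):

import torch

def row_is_batch_invariant(op, row: torch.Tensor, rest_of_batch: torch.Tensor) -> bool:
    # Compare the first row's result when computed alone vs. inside a batch.
    alone = op(row.unsqueeze(0))[0]
    batched = op(torch.cat([row.unsqueeze(0), rest_of_batch], dim=0))[0]
    return torch.equal(alone, batched)

# e.g. for op = lambda x: x @ w with w = torch.randn(256, 256), this can be
# False with default kernels; batch-invariant mode is what enforces equality
# for the ops it patches.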
@@ -75,6 +75,7 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
reqs = batch.reqs
device = batch.device
...
@@ -118,6 +118,8 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
# Allow external code to add more choices
def add_load_format_choices(choices):
@@ -437,6 +439,9 @@ class ServerArgs:
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
# For deterministic inference
enable_deterministic_inference: bool = False
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
@@ -980,6 +985,29 @@ class ServerArgs:
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
)
# Deterministic inference
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = (
"1" if self.enable_deterministic_inference else "0"
)
if self.enable_deterministic_inference:
# Check batch_invariant_ops dependency
import importlib
if not importlib.util.find_spec("batch_invariant_ops"):
raise ValueError(
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)
# Check some settings
self.disable_radix_cache = True
logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
)
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
@@ -2470,6 +2498,13 @@ class ServerArgs:
help="Number of sm partition groups.",
)
# For deterministic inference
parser.add_argument(
"--enable-deterministic-inference",
action="store_true",
help="Enable deterministic inference mode with batch invariant ops.",
)
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
...
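A condensed, hypothetical sketch of the flag-to-environment handshake added in this file: the CLI flag is a plain store_true, and the validation path exports it as SGLANG_ENABLE_DETERMINISTIC_INFERENCE so components that never see ServerArgs (such as the RMSNorm change above) can read it from the environment:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-deterministic-inference",
    action="store_true",
    help="Enable deterministic inference mode with batch invariant ops.",
)
args = parser.parse_args(["--enable-deterministic-inference"])
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = (
    "1" if args.enable_deterministic_inference else "0"
)
assert os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] == "1"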
"""
Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
Usage:
python3 -m sglang.test.test_deterministic --n-trials <number_of_trials> --test-mode <single|mixed|prefix> --profile
"""
import argparse
import dataclasses
import json
import os
import random
from typing import List
import requests
from sglang.profiler import run_profile
PROMPT_1 = "Tell me about Richard Feynman: "
PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
LONG_PROMPT = f.read()
@dataclasses.dataclass
class BenchArgs:
host: str = "localhost"
port: int = 30000
batch_size: int = 1
temperature: float = 0.0
max_new_tokens: int = 100
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
return_logprob: bool = False
stream: bool = False
profile: bool = False
profile_steps: int = 3
profile_by_stage: bool = False
test_mode: str = "single"
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--host", type=str, default=BenchArgs.host)
parser.add_argument("--port", type=int, default=BenchArgs.port)
parser.add_argument("--n-trials", type=int, default=50)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument(
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
)
parser.add_argument(
"--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
)
parser.add_argument(
"--presence-penalty", type=float, default=BenchArgs.presence_penalty
)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument("--stream", action="store_true")
parser.add_argument(
"--test-mode",
type=str,
default=BenchArgs.test_mode,
choices=["single", "mixed", "prefix"],
)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
"--profile-steps", type=int, default=BenchArgs.profile_steps
)
parser.add_argument("--profile-by-stage", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
def send_single(
args,
batch_size: int,
profile: bool = False,
profile_steps: int = 3,
profile_by_stage: bool = False,
):
base_url = f"http://{args.host}:{args.port}"
prompt = [PROMPT_1] * batch_size
json_data = {
"text": prompt,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if profile:
run_profile(
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
)
response = requests.post(
f"{base_url}/generate",
json=json_data,
stream=args.stream,
)
if args.stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
ret = ret[0]
if response.status_code != 200:
print(ret)
return -1
return ret["text"]
def send_mixed(args, batch_size: int):
num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
json_data = {
"text": [PROMPT_1] * num_prompt_1
+ [PROMPT_2] * num_prompt_2
+ [LONG_PROMPT] * num_long_prompt,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
stream=args.stream,
)
ret = response.json()
if response.status_code != 200:
print(ret)
return -1, -1, -1
prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
prompt_2_ret = [
ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
]
long_prompt_ret = [
ret[i]["text"]
for i in range(
num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
)
]
return prompt_1_ret, prompt_2_ret, long_prompt_ret
def send_prefix(args, batch_size: int, prompts: List[str]):
requests.post(f"http://{args.host}:{args.port}/flush_cache")
batch_data = []
sampled_indices = []
for _ in range(batch_size):
sampled_index = random.randint(0, len(prompts) - 1)
sampled_indices.append(sampled_index)
batch_data.append(prompts[sampled_index])
json_data = {
"text": batch_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
stream=args.stream,
)
ret = response.json()
if response.status_code != 200:
print(ret)
return -1, -1, -1
ret_dict = {i: [] for i in range(len(prompts))}
for i in range(batch_size):
ret_dict[sampled_indices[i]].append(ret[i]["text"])
return ret_dict
def test_deterministic(args):
# First do some warmups
for i in range(3):
send_single(args, 16, args.profile)
if args.test_mode == "single":
# In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
texts = []
for i in range(1, args.n_trials + 1):
batch_size = i
text = send_single(args, batch_size, args.profile)
text = text.replace("\n", " ")
print(f"Trial {i} with batch size {batch_size}: {text}")
texts.append(text)
print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
elif args.test_mode == "mixed":
# In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
output_prompt_1 = []
output_prompt_2 = []
output_long_prompt = []
for i in range(1, args.n_trials + 1):
batch_size = i
ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
output_prompt_1.extend(ret_prompt_1)
output_prompt_2.extend(ret_prompt_2)
output_long_prompt.extend(ret_long_prompt)
print(
f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
)
print(
f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
)
print(
f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
)
print(
f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
)
elif args.test_mode == "prefix":
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
len_prefix = [1, 511, 2048, 4097]
num_prompts = len(len_prefix)
outputs = {i: [] for i in range(4)}
prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
for i in range(1, args.n_trials + 1):
batch_size = i
ret_dict = send_prefix(args, batch_size, prompts)
msg = f"Testing Trial {i} with batch size {batch_size},"
for i in range(num_prompts):
msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
print(msg)
for i in range(num_prompts):
outputs[i].extend(ret_dict[i])
for i in range(num_prompts):
print(
f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
)
else:
raise ValueError(f"Invalid test mode: {args.test_mode}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
BenchArgs.add_cli_args(parser)
args = parser.parse_args()
test_deterministic(args)
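The script only prints total vs. unique sample counts; if a hard pass/fail is wanted on top of it, a minimal (hypothetical) assertion is that every trial of the same prompt decodes to exactly one unique completion when the server runs with --enable-deterministic-inference:

def assert_deterministic(texts):
    # texts: completions collected across trials for one fixed prompt.
    unique = set(texts)
    assert len(unique) == 1, f"expected 1 unique completion, got {len(unique)}"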