"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "7d89e48f5c1a8e719d5b8d154151c39685f50919"
Unverified Commit 134b4f7e authored by Ethan (Yusheng) Su's avatar Ethan (Yusheng) Su Committed by GitHub
Browse files

Support deterministic inference with triton backend (#10694)

parent f67d1f45
......@@ -201,6 +201,8 @@ class Envs:
SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
# fmt: on
......
......@@ -12,7 +12,12 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
from sglang.srt.utils import (
get_bool_env_var,
get_device_core_count,
get_int_env_var,
next_power_of_2,
)
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -94,7 +99,25 @@ class TritonAttnBackend(AttentionBackend):
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
)
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
# Decide whether enable deterministic inference with batch-invariant operations
self.enable_deterministic = (
model_runner.server_args.enable_deterministic_inference
)
# Configure deterministic inference settings
if self.enable_deterministic:
# Use fixed split tile size for batch invariance
self.split_tile_size = get_int_env_var(
"SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
)
# Set static_kv_splits to False to use deterministic logic instead
self.static_kv_splits = False
else:
self.split_tile_size = (
model_runner.server_args.triton_attention_split_tile_size
)
if self.split_tile_size is not None:
self.max_kv_splits = (
self.max_context_len + self.split_tile_size - 1
......@@ -154,13 +177,23 @@ class TritonAttnBackend(AttentionBackend):
num_group * num_seq == num_token
), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
if self.static_kv_splits or self.device_core_count <= 0:
# Legacy dynamic splitting logic (non-deterministic)
if (
self.static_kv_splits or self.device_core_count <= 0
) and not self.enable_deterministic:
num_kv_splits.fill_(self.max_kv_splits)
return
if self.split_tile_size is not None:
# deterministic
if self.split_tile_size is not None and self.enable_deterministic:
# expand seq_lens to match num_token
if num_group > 1:
expanded_seq_lens = seq_lens.repeat_interleave(num_group)
else:
expanded_seq_lens = seq_lens
num_kv_splits[:] = (
seq_lens + self.split_tile_size - 1
expanded_seq_lens + self.split_tile_size - 1
) // self.split_tile_size
return
......
......@@ -565,16 +565,8 @@ class Scheduler(
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
if (
self.server_args.enable_deterministic_inference
and self.server_args.attention_backend == "flashinfer"
):
self.truncation_align_size = get_int_env_var(
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
)
else:
self.truncation_align_size = None
# Init prefill kv split size when deterministic inference is enabled with various attention backends
self.init_deterministic_inference_config()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
......@@ -621,6 +613,23 @@ class Scheduler(
]
)
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
self.truncation_align_size = None
return
backend_sizes = {
"flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
"triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}
env_var, default_size = backend_sizes.get(
self.server_args.attention_backend, (None, None)
)
self.truncation_align_size = (
get_int_env_var(env_var, default_size) if env_var else None
)
def init_tokenizer(self):
server_args = self.server_args
self.is_generation = self.model_config.is_generation
......
......@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
# Allow external code to add more choices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment