Unverified Commit 632b7d8c authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Use simulate acc len from `sglang.environ` (#10771)

parent 16adf3dc
...@@ -124,6 +124,8 @@ class Envs: ...@@ -124,6 +124,8 @@ class Envs:
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False) SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False) SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False) SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
# Model Parallel # Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True) SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
......
...@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server ...@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
MOVE_ENVS_WARN = """
########################################################################
# For contributors and developers: #
# Please move environment variable definitions to 'sglang/environ.py' #
# using the following pattern: #
# SGLANG_XXX = EnvBool(False) #
# #
########################################################################
"""
if __name__ == "__main__": if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
from sglang.srt.server_args import print_deprecated_warning
print_deprecated_warning(MOVE_ENVS_WARN)
try: try:
launch_server(server_args) launch_server(server_args)
finally: finally:
......
...@@ -12,6 +12,7 @@ import torch.nn.functional as F ...@@ -12,6 +12,7 @@ import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.environ import envs
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict, global_server_args_dict,
) )
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda(): if is_cuda():
...@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__) ...@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes # Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN") SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial") SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
...@@ -500,13 +501,12 @@ class EagleVerifyInput: ...@@ -500,13 +501,12 @@ class EagleVerifyInput:
deterministic=True, deterministic=True,
) )
if SIMULATE_ACC_LEN: if SIMULATE_ACC_LEN > 0.0:
# Do simulation # Do simulation
accept_index = _generate_simulated_accept_index( accept_index = _generate_simulated_accept_index(
accept_index=accept_index, accept_index=accept_index,
predict=predict, # mutable predict=predict, # mutable
accept_length=accept_length, # mutable accept_length=accept_length, # mutable
simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs, bs=bs,
spec_steps=self.spec_steps, spec_steps=self.spec_steps,
) )
...@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index( ...@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
accept_index, accept_index,
predict, predict,
accept_length, accept_length,
simulate_acc_len,
bs, bs,
spec_steps, spec_steps,
simulate_acc_len: float = SIMULATE_ACC_LEN,
simulate_acc_method: str = SIMULATE_ACC_METHOD,
): ):
simulate_acc_len_float = float(simulate_acc_len) assert simulate_acc_len > 0.0
if SIMULATE_ACC_METHOD == "multinomial":
if simulate_acc_method == "multinomial":
simulated_values = torch.normal( simulated_values = torch.normal(
mean=simulate_acc_len_float, mean=simulate_acc_len,
std=1.0, std=1.0,
size=(1,), size=(1,),
device="cpu", device="cpu",
...@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index( ...@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
# clamp simulated values to be between 1 and spec_steps + 1 # clamp simulated values to be between 1 and spec_steps + 1
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item()) simulate_acc_len = int(simulated_values.round().item())
elif SIMULATE_ACC_METHOD == "match-expected": elif simulate_acc_method == "match-expected":
# multinomial sampling does not match the expected length # multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests # we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to # but it's better to use "match-expected" for the cases that need to
# match the expected length. One caveat is that this will only sample # match the expected length. One caveat is that this will only sample
# either the rounded-down or the rounded-up expected length # either the rounded-down or the rounded-up expected length
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float)) simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
lower = int(simulate_acc_len_float // 1) lower = int(simulate_acc_len // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper: if lower == upper:
simulate_acc_len = lower simulate_acc_len = lower
else: else:
weight_upper = simulate_acc_len_float - lower weight_upper = simulate_acc_len - lower
weight_lower = 1.0 - weight_upper weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu") probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1) sampled_index = torch.multinomial(probs, num_samples=1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment