Unverified Commit 632b7d8c authored by Liangsheng Yin, committed by GitHub

Use simulate acc len from `sglang.environ` (#10771)

parent 16adf3dc
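
For reference, a minimal before/after sketch of the access pattern this commit migrates. The `envs` accessor calls are taken from the hunks below; everything else is illustrative.

```python
import os

from sglang.environ import envs

# Before: raw, untyped read; the value arrives as a string (or None) and
# must be parsed at each call site.
simulate_acc_len = os.environ.get("SIMULATE_ACC_LEN")

# After: typed, defaulted read through the central registry in sglang.environ.
simulate_acc_len = envs.SGLANG_SIMULATE_ACC_LEN.get()        # EnvFloat, default -1 (off)
simulate_acc_method = envs.SGLANG_SIMULATE_ACC_METHOD.get()  # EnvStr, default "multinomial"
```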
@@ -124,6 +124,8 @@ class Envs:
     SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
     SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
     SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
+    SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
+    SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     # Model Parallel
     SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
......
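
A hedged usage sketch: assuming each `Envs` attribute maps to an environment variable of the same name (the convention the contributor banner in the next hunk implies), simulated acceptance lengths can be enabled for a benchmark run like this. The value 3.2 is purely illustrative; the default of -1 leaves simulation off.

```python
import os

# Illustrative only: turn on simulated acceptance lengths for benchmarking.
# Any value > 0 enables simulation; the default of -1 disables it.
os.environ["SGLANG_SIMULATE_ACC_LEN"] = "3.2"
# "multinomial" (default) samples around the target; "match-expected"
# rounds probabilistically so the mean matches the target.
os.environ["SGLANG_SIMULATE_ACC_METHOD"] = "match-expected"
```

These must be set before the speculative-decoding module is imported, since the module-level constants in the later hunks are read once at import time.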
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree
 
+MOVE_ENVS_WARN = """
+########################################################################
+# For contributors and developers:                                     #
+# Please move environment variable definitions to 'sglang/environ.py'  #
+# using the following pattern:                                         #
+# SGLANG_XXX = EnvBool(False)                                          #
+#                                                                      #
+########################################################################
+"""
+
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
 
+    from sglang.srt.server_args import print_deprecated_warning
+
+    print_deprecated_warning(MOVE_ENVS_WARN)
+
     try:
         launch_server(server_args)
     finally:
......
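
The banner points contributors at a single place and pattern for new variables. A hypothetical sketch of what that looks like, assuming `EnvBool`, `EnvFloat`, and `EnvStr` (the wrapper types used elsewhere in this diff) are importable from `sglang.environ`; the `SGLANG_MY_*` names are made up for illustration.

```python
from sglang.environ import EnvBool, EnvFloat, EnvStr


class Envs:
    # ... existing SGLANG_* definitions ...
    SGLANG_MY_NEW_FLAG = EnvBool(False)      # hypothetical boolean switch
    SGLANG_MY_NEW_RATIO = EnvFloat(0.5)      # hypothetical float knob
    SGLANG_MY_NEW_MODE = EnvStr("balanced")  # hypothetical string mode
```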
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 
+from sglang.environ import envs
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
 if is_cuda():
@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
 
 # Simulate acceptance length for benchmarking purposes
-SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
-SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
+SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
+SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
@@ -500,13 +501,12 @@ class EagleVerifyInput:
             deterministic=True,
         )
 
-        if SIMULATE_ACC_LEN:
+        if SIMULATE_ACC_LEN > 0.0:
             # Do simulation
             accept_index = _generate_simulated_accept_index(
                 accept_index=accept_index,
                 predict=predict,  # mutable
                 accept_length=accept_length,  # mutable
-                simulate_acc_len=SIMULATE_ACC_LEN,
                 bs=bs,
                 spec_steps=self.spec_steps,
             )
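
One subtlety of moving `simulate_acc_len` from a call-site argument to a keyword default (next hunk): `simulate_acc_len: float = SIMULATE_ACC_LEN` binds the module-level constant when the function is defined, i.e. at import time. A small self-contained sketch of that behavior, with made-up names rather than the real module:

```python
import os

os.environ["DEMO_ACC_LEN"] = "3.0"
DEMO_ACC_LEN = float(os.environ.get("DEMO_ACC_LEN", "-1"))


def simulate(acc_len: float = DEMO_ACC_LEN) -> float:
    # The default value is captured here, at definition time.
    return acc_len


os.environ["DEMO_ACC_LEN"] = "5.0"  # changing the variable afterwards...
print(simulate())  # ...still prints 3.0: the default was already bound
```

So `SGLANG_SIMULATE_ACC_LEN` has to be exported before the worker process imports this module, exactly as with the old `os.environ` read.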
@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
     accept_index,
     predict,
     accept_length,
-    simulate_acc_len,
     bs,
     spec_steps,
+    simulate_acc_len: float = SIMULATE_ACC_LEN,
+    simulate_acc_method: str = SIMULATE_ACC_METHOD,
 ):
-    simulate_acc_len_float = float(simulate_acc_len)
-    if SIMULATE_ACC_METHOD == "multinomial":
+    assert simulate_acc_len > 0.0
+
+    if simulate_acc_method == "multinomial":
         simulated_values = torch.normal(
-            mean=simulate_acc_len_float,
+            mean=simulate_acc_len,
             std=1.0,
             size=(1,),
             device="cpu",
@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
         # clamp simulated values to be between 1 and self.spec_steps
         simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
         simulate_acc_len = int(simulated_values.round().item())
-    elif SIMULATE_ACC_METHOD == "match-expected":
+    elif simulate_acc_method == "match-expected":
         # multinomial sampling does not match the expected length
         # we keep it for the sake of compatibility of existing tests
         # but it's better to use "match-expected" for the cases that need to
         # match the expected length, One caveat is that this will only sample
         # either round down or round up of the expected length
-        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
-        lower = int(simulate_acc_len_float // 1)
+        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
+        lower = int(simulate_acc_len // 1)
         upper = lower + 1 if lower < spec_steps + 1 else lower
         if lower == upper:
             simulate_acc_len = lower
         else:
-            weight_upper = simulate_acc_len_float - lower
+            weight_upper = simulate_acc_len - lower
             weight_lower = 1.0 - weight_upper
             probs = torch.tensor([weight_lower, weight_upper], device="cpu")
             sampled_index = torch.multinomial(probs, num_samples=1)
......
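
The `match-expected` branch is the part worth studying: it clamps the target to `[1, spec_steps + 1]`, then rounds it down or up at random, weighting the two outcomes so the expectation equals the clamped target. A standalone sketch of that rounding rule, paraphrased from the hunk above rather than imported from sglang:

```python
import torch


def match_expected_round(target: float, spec_steps: int) -> int:
    """Round `target` to its floor or ceiling so E[result] equals the clamped target."""
    target = max(1.0, min(spec_steps + 1, target))
    lower = int(target // 1)
    upper = lower + 1 if lower < spec_steps + 1 else lower
    if lower == upper:
        return lower
    weight_upper = target - lower  # e.g. 0.2 for a target of 3.2
    probs = torch.tensor([1.0 - weight_upper, weight_upper])
    return upper if torch.multinomial(probs, num_samples=1).item() == 1 else lower


# Quick sanity check: the empirical mean approaches the target.
samples = [match_expected_round(3.2, spec_steps=4) for _ in range(20_000)]
print(sum(samples) / len(samples))  # ≈ 3.2
```

As the in-code comment notes, this only ever samples the floor or the ceiling of the target, unlike the normal-distribution path used by `"multinomial"`.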