Unverified commit 0c1f03a2 authored by Lianmin Zheng, committed by GitHub

Sync cuda graph runners (#6976)

parent 3712abfa
@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.positions = self.positions[:num_tokens]
# Special handle for seq_len_cpu used when flashinfer mla is used
if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
......
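The seq_lens_cpu handling in the hunk above pads the per-request sequence lengths when the live batch (`raw_bs`) is smaller than the batch size the CUDA graph was captured with (`bs`): the preallocated buffer is filled with 1, the real lengths are copied into its prefix, and the forward batch gets a view of the padded buffer. A minimal sketch of that fill-then-copy pattern, with illustrative names rather than the runner's actual attributes:

```python
import torch

# Batch size the graph was captured with vs. the batch actually being run.
captured_bs, raw_bs = 8, 3

# Preallocated buffer owned by the runner; padded slots use a dummy length of 1.
seq_lens_cpu_buf = torch.empty(captured_bs, dtype=torch.int64)
seq_lens_cpu_buf.fill_(1)

# Real sequence lengths for the live requests.
real_seq_lens_cpu = torch.tensor([17, 5, 42], dtype=torch.int64)
seq_lens_cpu_buf[:raw_bs].copy_(real_seq_lens_cpu)

# The forward batch then sees a view sized to the captured batch.
padded_view = seq_lens_cpu_buf[:captured_bs]
print(padded_view)  # tensor([17,  5, 42,  1,  1,  1,  1,  1])
```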
@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
......
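Both runner hunks above change `seq_lens_sum=seq_lens.sum()` to `seq_lens_sum=seq_lens.sum().item()` when building the forward batch. The difference is only the result type: `.sum()` returns a 0-dim tensor, while `.item()` yields a plain Python int, presumably so the scalar field holds a host-side integer (and the device-to-host sync happens explicitly here) instead of carrying a tensor around. A quick illustration:

```python
import torch

seq_lens = torch.tensor([3, 5, 2])

as_tensor = seq_lens.sum()         # 0-dim tensor: tensor(10)
as_int = seq_lens.sum().item()     # plain Python int: 10

print(type(as_tensor), as_tensor)  # <class 'torch.Tensor'> tensor(10)
print(type(as_int), as_int)        # <class 'int'> 10
```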
from __future__ import annotations
import logging
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -12,6 +14,7 @@ import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
@@ -34,15 +36,15 @@ if is_cuda():
elif is_hip():
from sgl_kernel import verify_tree_greedy
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
import logging
logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
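The import hunks above add `from __future__ import annotations` and place the `ScheduleBatch` import under an `if TYPE_CHECKING:` guard. This is the standard pattern for type-only imports: with postponed evaluation, annotations stay as strings at runtime, so the guarded import is only seen by static checkers and cannot create an import cycle. A generic sketch of the pattern, with a hypothetical module path:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-checker-only import; never executed at runtime, so it cannot
    # create an import cycle. The module path is a hypothetical stand-in.
    from some_scheduler_module import ScheduleBatch


def run(batch: ScheduleBatch) -> None:
    # With postponed evaluation of annotations, "ScheduleBatch" stays a
    # string at runtime and is only resolved by static type checkers.
    ...
```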
@dataclass
@@ -84,9 +86,9 @@ class EagleDraftInput:
self,
batch: ScheduleBatch,
speculative_num_steps: int,
context_length: int,
pad_input: bool = False,
):
assert len(self.verified_id) == len(batch.out_cache_loc)
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
@@ -112,49 +114,49 @@ class EagleDraftInput:
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
if pad_input:
batch_size = sum(not req.finished() for req in batch.reqs)
# Total constant input length after padding
static_len = speculative_num_steps + 1
# Total size after padding
padded_input_size = batch_size * static_len
padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(
padded_len, device=self.positions.device
)
new_positions = torch.cat([self.positions, position_padding])
# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)
if not pad_input:
return
# allocate KV cache location for the padded tokens
padded_cache_loc = torch.zeros(
padded_len,
dtype=batch.out_cache_loc.dtype,
device=batch.out_cache_loc.device,
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
batch_size = sum(not req.finished() for req in batch.reqs)
# Total constant input length after padding
static_len = speculative_num_steps + 1
# Total size after padding
padded_input_size = batch_size * static_len
padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(padded_len, device=self.positions.device)
new_positions = torch.cat([self.positions, position_padding])
# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)
# allocate KV cache location for the padded tokens
padded_cache_loc = torch.zeros(
padded_len,
dtype=batch.out_cache_loc.dtype,
device=batch.out_cache_loc.device,
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
batch.input_ids = new_input_ids
self.hidden_states = new_hidden_states
self.positions = new_positions
batch.out_cache_loc = new_out_cache_loc
batch.input_ids = new_input_ids
self.hidden_states = new_hidden_states
self.positions = new_positions
batch.out_cache_loc = new_out_cache_loc
def generate_attn_arg_prefill(
self,
......
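The padding branch above grows `input_ids`, `positions`, `hidden_states`, and `out_cache_loc` to a static size of `batch_size * (speculative_num_steps + 1)` so the draft-extend CUDA graph can be replayed with fixed shapes. A minimal, self-contained sketch of the same right-padding idea (shapes and names here are illustrative only):

```python
import torch
import torch.nn.functional as F

speculative_num_steps = 3
static_len = speculative_num_steps + 1   # constant per-request length after padding
batch_size = 2
hidden_dim = 4

# Variable-length inputs produced after the verify step.
input_ids = torch.tensor([11, 12, 13, 21, 22])   # 5 real tokens
positions = torch.tensor([0, 1, 2, 0, 1])
hidden_states = torch.randn(5, hidden_dim)

padded_len = batch_size * static_len - input_ids.shape[0]  # 8 - 5 = 3

if padded_len > 0:
    # Right-pad token ids with zeros and append dummy positions / hidden states.
    input_ids = F.pad(input_ids, (0, padded_len), value=0)
    positions = torch.cat([positions, torch.arange(padded_len)])
    hidden_states = torch.cat(
        [hidden_states, torch.zeros(padded_len, hidden_dim)], dim=0
    )

print(input_ids.shape, positions.shape, hidden_states.shape)
# torch.Size([8]) torch.Size([8]) torch.Size([8, 4])
```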
@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
self.server_args.context_length,
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
......
@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_logprob_check,
)
@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtend(CustomTestCase):
@classmethod
def setUpClass(cls):
@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
cls.accept_len_threshold = 1.50
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
@classmethod
def setUpClass(cls):
@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
cls.accept_len_threshold = 1.50
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
@classmethod
def setUpClass(cls):
......
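The test hunks above import `is_in_ci` from `sglang.test.test_utils` and gate the slower `TestEAGLEDraftExtend*` classes with `@unittest.skipIf(is_in_ci(), ...)`. A generic sketch of that pattern, assuming an environment-variable check as a hypothetical stand-in for the real helper:

```python
import os
import unittest


def is_in_ci() -> bool:
    # Hypothetical stand-in: many CI systems export a flag such as "CI".
    return os.environ.get("CI", "").lower() in ("1", "true")


@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class SlowIntegrationTest(unittest.TestCase):
    def test_something_expensive(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```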