"references/vscode:/vscode.git/clone" did not exist on "0dceac025615a1c2df6ec1675d8f9d7757432a49"
Unverified Commit 0c1f03a2 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Sync cuda graph runners (#6976)

parent 3712abfa
...@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner: ...@@ -127,7 +127,7 @@ class EAGLEDraftCudaGraphRunner:
req_to_token_pool=self.model_runner.req_to_token_pool, req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(), seq_lens_sum=seq_lens.sum().item(),
return_logprob=False, return_logprob=False,
positions=positions, positions=positions,
spec_algorithm=self.model_runner.spec_algorithm, spec_algorithm=self.model_runner.spec_algorithm,
...@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner: ...@@ -209,7 +209,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.positions = self.positions[:num_tokens] forward_batch.positions = self.positions[:num_tokens]
# Special handle for seq_len_cpu used when flashinfer mla is used # Special handle for seq_len_cpu used when flashinfer mla is used
if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs): if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
self.seq_lens_cpu.fill_(1) self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs] forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
......
...@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -138,7 +138,7 @@ class EAGLEDraftExtendCudaGraphRunner:
req_to_token_pool=self.model_runner.req_to_token_pool, req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(), seq_lens_sum=seq_lens.sum().item(),
return_logprob=False, return_logprob=False,
positions=positions, positions=positions,
spec_algorithm=self.model_runner.spec_algorithm, spec_algorithm=self.model_runner.spec_algorithm,
......
from __future__ import annotations from __future__ import annotations
import logging
import os import os
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -12,6 +14,7 @@ import triton.language as tl ...@@ -12,6 +14,7 @@ import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
Req, Req,
ScheduleBatch, ScheduleBatch,
...@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -20,7 +23,6 @@ from sglang.srt.managers.schedule_batch import (
) )
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
...@@ -34,15 +36,15 @@ if is_cuda(): ...@@ -34,15 +36,15 @@ if is_cuda():
elif is_hip(): elif is_hip():
from sgl_kernel import verify_tree_greedy from sgl_kernel import verify_tree_greedy
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN") SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
@dataclass @dataclass
...@@ -84,9 +86,9 @@ class EagleDraftInput: ...@@ -84,9 +86,9 @@ class EagleDraftInput:
self, self,
batch: ScheduleBatch, batch: ScheduleBatch,
speculative_num_steps: int, speculative_num_steps: int,
context_length: int,
pad_input: bool = False, pad_input: bool = False,
): ):
assert len(self.verified_id) == len(batch.out_cache_loc)
accept_length_cpu = batch.spec_info.accept_length_cpu accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens) batch.extend_num_tokens = sum(batch.extend_lens)
...@@ -112,49 +114,49 @@ class EagleDraftInput: ...@@ -112,49 +114,49 @@ class EagleDraftInput:
batch.input_ids = self.verified_id batch.input_ids = self.verified_id
self.verified_id = new_verified_id self.verified_id = new_verified_id
if pad_input: if not pad_input:
batch_size = sum(not req.finished() for req in batch.reqs) return
# Total constant input length after padding
static_len = speculative_num_steps + 1
# Total size after padding
padded_input_size = batch_size * static_len
padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(
padded_len, device=self.positions.device
)
new_positions = torch.cat([self.positions, position_padding])
# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)
# allocate KV cache location for the padded tokens batch_size = sum(not req.finished() for req in batch.reqs)
padded_cache_loc = torch.zeros( # Total constant input length after padding
padded_len, static_len = speculative_num_steps + 1
dtype=batch.out_cache_loc.dtype, # Total size after padding
device=batch.out_cache_loc.device, padded_input_size = batch_size * static_len
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc]) padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(padded_len, device=self.positions.device)
new_positions = torch.cat([self.positions, position_padding])
# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)
# allocate KV cache location for the padded tokens
padded_cache_loc = torch.zeros(
padded_len,
dtype=batch.out_cache_loc.dtype,
device=batch.out_cache_loc.device,
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
batch.input_ids = new_input_ids batch.input_ids = new_input_ids
self.hidden_states = new_hidden_states self.hidden_states = new_hidden_states
self.positions = new_positions self.positions = new_positions
batch.out_cache_loc = new_out_cache_loc batch.out_cache_loc = new_out_cache_loc
def generate_attn_arg_prefill( def generate_attn_arg_prefill(
self, self,
......
...@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -687,6 +687,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode( batch.spec_info.prepare_extend_after_decode(
batch, batch,
self.speculative_num_steps, self.speculative_num_steps,
self.server_args.context_length,
pad_input=self.cuda_graph_runner_for_draft_extend is not None, pad_input=self.cuda_graph_runner_for_draft_extend is not None,
) )
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
......
...@@ -23,6 +23,7 @@ from sglang.test.test_utils import ( ...@@ -23,6 +23,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_ci,
popen_launch_server, popen_launch_server,
run_logprob_check, run_logprob_check,
) )
...@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer): ...@@ -578,6 +579,7 @@ class TestEAGLEServerTriton(TestEAGLEServer):
) )
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtend(CustomTestCase): class TestEAGLEDraftExtend(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend): ...@@ -669,6 +671,7 @@ class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
cls.accept_len_threshold = 1.50 cls.accept_len_threshold = 1.50
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend): class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend): ...@@ -697,6 +700,7 @@ class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
cls.accept_len_threshold = 1.50 cls.accept_len_threshold = 1.50
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend): class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment