Unverified commit e7bc6003, authored by Zhihao Zhang, committed by GitHub

[Feature] Speculative decoding support lookahead (#9873)


Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
parent 2a2ff9a8
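Unlike the EAGLE/standalone suites, lookahead speculation launches with only the target model; no separate draft model is configured anywhere in this change. A minimal launch sketch assembled from the test arguments added below (the model path is just the new test default, not a requirement):

python3 -m sglang.launch_server \
  --model-path Qwen/Qwen2.5-Coder-7B-Instruct \
  --speculative-algorithm LOOKAHEAD \
  --speculative-num-draft-tokens 16 \
  --cuda-graph-max-bs 8 \
  --mem-fraction-static 0.8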
......@@ -80,6 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Other use cases
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
......
......@@ -318,6 +318,7 @@ set(SOURCES
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/lookahead_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"
......
......@@ -291,6 +291,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
"reconstruct_indices_from_tree_mask(Tensor tree_mask, Tensor verified_seq_len, Tensor positions, "
"Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"int batch_size, int draft_token_num) -> ()");
m.impl("reconstruct_indices_from_tree_mask", torch::kCUDA, &reconstruct_indices_from_tree_mask);
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
// tree_mask: [bs * draft_token_num * draft_token_num]
// verified_seq_len: [bs]
// positions: [bs * draft_token_num]
// retrive_index: [bs, draft_token_num]
// retrive_next_token: [bs, draft_token_num]
// retrive_next_sibling: [bs, draft_token_num]
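// Worked example (draft_token_num = 4), mirroring the unit test in this change.
// Each row of the mask marks a token together with all of its ancestors:
//   row 0: 1 0 0 0   token 0 is the root
//   row 1: 1 1 0 0   token 1 is a child of token 0
//   row 2: 1 0 1 0   token 2 is a child of token 0
//   row 3: 1 0 1 1   token 3 is a child of token 2
// -> retrive_next_token = [1, -1, 3, -1], retrive_next_sibling = [-1, 2, -1, -1],
//    positions = verified_seq_len + depth = [12, 13, 13, 14] for verified_seq_len = 12.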
__global__ void reconstructIndicesFromTreeMask(
bool* tree_mask,
int64_t* verified_seq_len,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int batch_size,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (bid >= batch_size || tid >= draft_token_num) {
return;
}
int base_offset = draft_token_num * draft_token_num;
// token_idx: [bid * draft_token_num, (bid + 1) * draft_token_num)
int token_idx = bid * draft_token_num;
// tree_mask_idx: [bid * base_offset, (bid + 1) * base_offset)
int tree_mask_offset = bid * base_offset;
int depth = 0;
int parent_idx = -1;
// Scan this token's mask row leftwards: every set bit below tid marks an
// ancestor, and the nearest one (largest i < tid) is the direct parent.
for (int i = tid - 1, start_idx = tree_mask_offset + tid * draft_token_num; i >= 0; i--) {
if (tree_mask[start_idx + i]) {
depth++;
if (parent_idx == -1) {
parent_idx = i;
}
}
}
retrive_index[token_idx + tid] = token_idx + tid;
positions[token_idx + tid] = depth + verified_seq_len[bid];
int next_token_idx = -1;
// The first later token whose mask row marks tid as an ancestor is tid's first child.
for (int i = tid + 1; i < draft_token_num; i++) {
if (tree_mask[tree_mask_offset + i * draft_token_num + tid]) {
next_token_idx = i;
break;
}
}
retrive_next_token[token_idx + tid] = next_token_idx;
int next_sibling_idx = -1;
if (parent_idx != -1) {
// The next sibling is the first later token that also descends directly from
// parent_idx, i.e. whose row has no set bit strictly between parent_idx and itself.
for (int i = tid + 1; i < draft_token_num; i++) {
int start_idx = tree_mask_offset + i * draft_token_num + parent_idx;
if (tree_mask[start_idx]) {
bool is_sibling = true;
int end_idx = tree_mask_offset + i * draft_token_num + i;
for (int j = start_idx + 1; j < end_idx; ++j) {
if (tree_mask[j]) {
is_sibling = false;
break;
}
}
if (is_sibling) {
next_sibling_idx = i;
break;
}
}
}
}
retrive_next_sibling[token_idx + tid] = next_sibling_idx;
}
void reconstruct_indices_from_tree_mask(
at::Tensor tree_mask,
at::Tensor verified_seq_len,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t batch_size,
int64_t draft_token_num) {
dim3 grid(batch_size);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstructIndicesFromTreeMask<<<grid, block, 0, stream>>>(
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int(batch_size),
int(draft_token_num));
}
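For readers following the kernel, here is a minimal pure-PyTorch sketch of the same traversal. The function name `reconstruct_indices_reference` is invented for this note (it is not part of the change); it reproduces the unit test below on CPU:

import torch

def reconstruct_indices_reference(tree_mask, verified_seq_len, batch_size, draft_token_num):
    # Reshape the flat mask to [bs, draft, draft]; row t marks token t and its ancestors.
    mask = tree_mask.view(batch_size, draft_token_num, draft_token_num).cpu()
    seq_len = verified_seq_len.cpu()
    positions = torch.empty(batch_size * draft_token_num, dtype=torch.int64)
    retrive_index = torch.empty((batch_size, draft_token_num), dtype=torch.int64)
    retrive_next_token = torch.full((batch_size, draft_token_num), -1, dtype=torch.int64)
    retrive_next_sibling = torch.full((batch_size, draft_token_num), -1, dtype=torch.int64)
    for b in range(batch_size):
        for t in range(draft_token_num):
            ancestors = [i for i in range(t) if mask[b, t, i]]
            parent = ancestors[-1] if ancestors else -1  # nearest set bit = parent
            retrive_index[b, t] = b * draft_token_num + t
            positions[b * draft_token_num + t] = len(ancestors) + seq_len[b]
            # First later token that lists t as an ancestor is t's first child.
            for i in range(t + 1, draft_token_num):
                if mask[b, i, t]:
                    retrive_next_token[b, t] = i
                    break
            # Next sibling: shares the parent and has no closer ancestor in between.
            for i in range(t + 1, draft_token_num):
                if parent != -1 and mask[b, i, parent] and not mask[b, i, parent + 1 : i].any():
                    retrive_next_sibling[b, t] = i
                    break
    return positions, retrive_index, retrive_next_token, retrive_next_sibling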
......@@ -457,6 +457,16 @@ void verify_tree_greedy(
at::Tensor target_predict,
int64_t cuda_stream = 0);
void reconstruct_indices_from_tree_mask(
at::Tensor tree_mask,
at::Tensor verified_seq_len,
at::Tensor positions, // mutable
at::Tensor retrive_index, // mutable
at::Tensor retrive_next_token, // mutable
at::Tensor retrive_next_sibling, // mutable
int64_t batch_size,
int64_t draft_token_num);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
......
......@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
reconstruct_indices_from_tree_mask,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,
......
......@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
)
def reconstruct_indices_from_tree_mask(
tree_mask: torch.Tensor,
verified_seq_len: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
batch_size: int,
draft_token_num: int,
) -> None:
torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default(
tree_mask,
verified_seq_len,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
batch_size,
draft_token_num,
)
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,
......
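The four index/position outputs are written in place (see the `// mutable` annotations in the header above), so callers preallocate them on the GPU. A minimal allocation sketch matching the dtypes the unit test below uses; `bs` and `draft_token_num` here are arbitrary example sizes, not required by the op:

import torch
from sgl_kernel import reconstruct_indices_from_tree_mask

bs, draft_token_num = 2, 8  # example sizes only
verified_seq_len = torch.zeros(bs, device="cuda", dtype=torch.int64)
tree_mask = torch.zeros(bs * draft_token_num * draft_token_num, device="cuda", dtype=torch.bool)
positions = torch.empty(bs * draft_token_num, device="cuda", dtype=torch.int64)
retrive_index = torch.full((bs, draft_token_num), -1, device="cuda", dtype=torch.int64)
retrive_next_token = torch.full((bs, draft_token_num), -1, device="cuda", dtype=torch.int64)
retrive_next_sibling = torch.full((bs, draft_token_num), -1, device="cuda", dtype=torch.int64)
reconstruct_indices_from_tree_mask(
    tree_mask, verified_seq_len, positions, retrive_index,
    retrive_next_token, retrive_next_sibling, bs, draft_token_num,
)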
import pytest
import torch
from sgl_kernel import reconstruct_indices_from_tree_mask
def test_reconstruct_indices_from_tree_mask():
bs = 1
num_branch_token = 4
seq_lens = torch.tensor([12], device="cuda", dtype=torch.int64)
retrive_index = torch.full(
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
)
retrive_next_token = torch.full(
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
)
retrive_next_sibling = torch.full(
(bs, num_branch_token), -1, device="cuda", dtype=torch.int64
)
positions = torch.empty((bs * num_branch_token), device="cuda", dtype=torch.int64)
tree_mask = torch.tensor(
[
# Row-major [token][ancestor] mask; row t marks token t and all its ancestors.
1, 0, 0, 0,  # token 0: root of the draft tree
1, 1, 0, 0,  # token 1: child of token 0
1, 0, 1, 0,  # token 2: child of token 0
1, 0, 1, 1,  # token 3: child of token 2
],
device="cuda",
dtype=torch.int32,
).to(torch.bool)
reconstruct_indices_from_tree_mask(
tree_mask,
seq_lens,
positions, # mutable
retrive_index, # mutable
retrive_next_token, # mutable
retrive_next_sibling, # mutable
bs,
num_branch_token,
)
# print(f"debug: \n\n{tree_mask=}, {retrive_index=}, {retrive_next_token=}, {retrive_next_sibling=}, {positions=}\n\n")
assert retrive_index.tolist() == [
[0, 1, 2, 3],
], f"{retrive_index=}"
assert retrive_next_token.tolist() == [
[1, -1, 3, -1],
], f"{retrive_next_token=}"
assert retrive_next_sibling.tolist() == [
[-1, 2, -1, -1],
], f"{retrive_next_sibling=}"
assert positions.tolist() == [
12,
13,
13,
14,
], f"{positions=}"
if __name__ == "__main__":
test_reconstruct_indices_from_tree_mask()
pytest.main([__file__])
......@@ -78,6 +78,7 @@ suites = {
TestFile("test_hidden_states.py", 55),
TestFile("test_hybrid_attn_backend.py", 100),
TestFile("test_standalone_speculative_decoding.py", 250),
TestFile("test_lookahead_speculative_decoding.py", 250),
TestFile("test_input_embeddings.py", 38),
TestFile("test_io_struct.py", 8),
TestFile("test_jinja_template_utils.py", 1),
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
GSM_DATASET_PATH = None
# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [
"--trust-remote-code",
"--cuda-graph-max-bs",
"8",
"--speculative-algorithm",
"LOOKAHEAD",
"--speculative-num-draft-tokens",
"16",
"--mem-fraction-static",
"0.8",
]
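# Note: unlike the standalone suite, no draft-model argument is passed here;
# LOOKAHEAD speculates from the target model's own context.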
class TestLookaheadSpeculativeDecodingBase(CustomTestCase):
model = DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
accuracy_threshold = 0.8 # derived tests need to override this
spec_decode_threshold = 1.8 # derived spec decoding tests need to override this
@classmethod
def get_server_args(cls):
"""Return the arguments for the server launch. Override in subclasses."""
return DEFAULT_SERVER_ARGS + ["--attention-backend", "fa3"]
@classmethod
def setUpClass(cls):
# Disable DeepGEMM JIT precompilation so the test server launches faster.
# Do not do this if you want your inference workload itself to be fast.
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
model = cls.model
cls.process = popen_launch_server(
model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=4,
num_questions=100,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
data_path=GSM_DATASET_PATH,
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], self.accuracy_threshold)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
class TestLookaheadSpeculativeDecodingTriton(TestLookaheadSpeculativeDecodingBase):
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
class TestLookaheadSpeculativeDecodingFlashinfer(TestLookaheadSpeculativeDecodingBase):
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
if __name__ == "__main__":
unittest.main()