Unverified commit 5589b750, authored by Lianmin Zheng, committed by GitHub
Browse files

Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (#7756)


Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
parent c04a8a82
# NOTE: Please run this file to make sure the test cases are correct. # NOTE: Please run this file to make sure the test cases are correct.
from typing import List import math
from enum import IntEnum
from typing import List, Optional
import torch import torch
from sglang.srt.utils import is_cuda, is_hip, rank0_log from sglang.srt.utils import is_cuda, is_hip
if is_cuda() or is_hip(): if is_cuda() or is_hip():
from sgl_kernel import ( from sgl_kernel import (
...@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess( ...@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
return parent_list, top_scores_index, draft_tokens return parent_list, top_scores_index, draft_tokens
class TreeMaskMode(IntEnum):
FULL_MASK = 0
QLEN_ONLY = 1
QLEN_ONLY_BITPACKING = 2
def build_tree_kernel_efficient( def build_tree_kernel_efficient(
verified_id: torch.Tensor, verified_id: torch.Tensor,
score_list: List[torch.Tensor], score_list: List[torch.Tensor],
...@@ -50,6 +58,9 @@ def build_tree_kernel_efficient( ...@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
topk: int, topk: int,
spec_steps: int, spec_steps: int,
num_verify_tokens: int, num_verify_tokens: int,
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
tree_mask_buf: Optional[torch.Tensor] = None,
position_buf: Optional[torch.Tensor] = None,
): ):
parent_list, top_scores_index, draft_tokens = ( parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess( build_tree_kernel_efficient_preprocess(
...@@ -66,15 +77,37 @@ def build_tree_kernel_efficient( ...@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
device = seq_lens.device device = seq_lens.device
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
# where each row indicates the attending pattern of each draft token # where each row indicates the attending pattern of each draft token
# if tree_mask_mode is TreeMaskMode.QLEN_ONLY_BITPACKING, tree_mask: num_draft_token (flattened, packed)
if tree_mask_buf is not None:
tree_mask = tree_mask_buf
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
tree_mask = torch.full(
(num_verify_tokens * bs * num_verify_tokens,),
True,
dtype=torch.bool,
device=device,
)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
tree_mask = torch.zeros(
(num_verify_tokens * bs,),
dtype=packed_dtypes[packed_dtype_idx],
device=device,
)
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
else:
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full( retrive_index = torch.full(
(bs, num_verify_tokens), -1, device=device, dtype=torch.long (bs, num_verify_tokens), -1, device=device, dtype=torch.long
) )
...@@ -87,7 +120,12 @@ def build_tree_kernel_efficient( ...@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
# position: where each token belongs to # position: where each token belongs to
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
# then, positions = [7, 8, 8, 9] # then, positions = [7, 8, 8, 9]
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long) if position_buf is not None:
positions = position_buf
else:
positions = torch.empty(
(bs * num_verify_tokens,), device=device, dtype=torch.long
)
sgl_build_tree_kernel_efficient( sgl_build_tree_kernel_efficient(
parent_list, parent_list,
...@@ -101,6 +139,7 @@ def build_tree_kernel_efficient( ...@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
topk, topk,
spec_steps, spec_steps,
num_verify_tokens, num_verify_tokens,
tree_mask_mode,
) )
return ( return (
tree_mask, tree_mask,
...@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient(): ...@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
num_verify_tokens=num_draft_token, num_verify_tokens=num_draft_token,
) )
rank0_log("=========== build tree kernel efficient ==========") print("=========== build tree kernel efficient ==========")
# rank0_log(f"{tree_mask=}") print(f"{tree_mask=}")
rank0_log(f"{position=}") print(f"{position=}")
rank0_log(f"{retrive_index=}") print(f"{retrive_index=}")
rank0_log(f"{retrive_next_token=}") print(f"{retrive_next_token=}")
rank0_log(f"{retrive_next_sibling=}") print(f"{retrive_next_sibling=}")
rank0_log(f"{draft_tokens=}") print(f"{draft_tokens=}")
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14] assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [ assert retrive_index.tolist() == [
[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7],
......
...@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def( m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
m.def( m.def(
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "pytorch_extension_utils_rocm.h" #include "pytorch_extension_utils_rocm.h"
#endif #endif
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
// parent_list [bs, topk * (depth - 1) + 1)] // parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1] // selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs] // verified_seq_len [bs]
...@@ -40,7 +42,8 @@ __global__ void build_tree_efficient( ...@@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
int64_t* retrive_next_sibling, int64_t* retrive_next_sibling,
int topk, int topk,
int depth, int depth,
int draft_token_num) { int draft_token_num,
int tree_mask_mode) {
int bid = blockIdx.x; int bid = blockIdx.x;
int tid = threadIdx.x; int tid = threadIdx.x;
...@@ -52,7 +55,13 @@ __global__ void build_tree_efficient( ...@@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
seq_tree_idx += verified_seq_len[i] * draft_token_num; seq_tree_idx += verified_seq_len[i] * draft_token_num;
} }
int seq_len = verified_seq_len[bid]; int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; int token_tree_idx;
if (tree_mask_mode == FULL_MASK) {
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
} else {
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
}
tree_mask[token_tree_idx - 1] = true;
for (int i = 0; i < draft_token_num - 1; i++) { for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false; tree_mask[token_tree_idx + i] = false;
} }
...@@ -124,7 +133,8 @@ void build_tree_kernel_efficient( ...@@ -124,7 +133,8 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling, at::Tensor retrive_next_sibling,
int64_t topk, int64_t topk,
int64_t depth, int64_t depth,
int64_t draft_token_num) { int64_t draft_token_num,
int64_t tree_mask_mode) {
// TODO (ying) check shape // TODO (ying) check shape
// TODO (ying) check type // TODO (ying) check type
int bs = parent_list.size(0); int bs = parent_list.size(0);
...@@ -132,18 +142,29 @@ void build_tree_kernel_efficient( ...@@ -132,18 +142,29 @@ void build_tree_kernel_efficient(
dim3 block(draft_token_num); dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree_efficient<<<grid, block, 0, stream>>>( if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
static_cast<int64_t*>(parent_list.data_ptr()), size_t num_bytes_per_item = 1;
static_cast<int64_t*>(selected_index.data_ptr()), if (draft_token_num > 16) {
static_cast<int64_t*>(verified_seq_len.data_ptr()), num_bytes_per_item = 4;
static_cast<bool*>(tree_mask.data_ptr()), } else if (draft_token_num > 8) {
static_cast<int64_t*>(positions.data_ptr()), num_bytes_per_item = 2;
static_cast<int64_t*>(retrive_index.data_ptr()), }
static_cast<int64_t*>(retrive_next_token.data_ptr()), throw std::runtime_error("Not implemented");
static_cast<int64_t*>(retrive_next_sibling.data_ptr()), } else {
int32_t(topk), build_tree_efficient<<<grid, block, 0, stream>>>(
int32_t(depth), static_cast<int64_t*>(parent_list.data_ptr()),
int32_t(draft_token_num)); static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
int32_t(tree_mask_mode));
}
} }
template <typename IdType, typename IdType2> template <typename IdType, typename IdType2>
......
...@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def( m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
} }
......
...@@ -374,7 +374,8 @@ void build_tree_kernel_efficient( ...@@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling, at::Tensor retrive_next_sibling,
int64_t topk, int64_t topk,
int64_t depth, int64_t depth,
int64_t draft_token_num); int64_t draft_token_num,
int64_t tree_mask_mode);
void segment_packbits( void segment_packbits(
at::Tensor x, at::Tensor x,
......
...@@ -72,6 +72,7 @@ def build_tree_kernel_efficient( ...@@ -72,6 +72,7 @@ def build_tree_kernel_efficient(
topk: int, topk: int,
depth: int, depth: int,
draft_token_num: int, draft_token_num: int,
tree_mask_mode: int,
) -> None: ) -> None:
torch.ops.sgl_kernel.build_tree_kernel_efficient.default( torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
parent_list, parent_list,
...@@ -85,6 +86,7 @@ def build_tree_kernel_efficient( ...@@ -85,6 +86,7 @@ def build_tree_kernel_efficient(
topk, topk,
depth, depth,
draft_token_num, draft_token_num,
tree_mask_mode,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment