import torch
from sgl_kernel.utils import get_cuda_stream


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,
    accept_index: torch.Tensor,
    accept_token_num: torch.Tensor,
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Invoke the ``tree_speculative_sampling_target_only`` sgl_kernel op.

    This is a pure pass-through: every argument is forwarded positionally,
    in declaration order, followed by the value of ``get_cuda_stream()``.

    Args:
        predicts: Marked mutable by the original author; presumably written
            in place by the kernel -- confirm against the op definition.
        accept_index: Marked mutable (same caveat as ``predicts``).
        accept_token_num: Marked mutable (same caveat as ``predicts``).
        candidates: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        retrive_next_token: Forwarded unchanged.
        retrive_next_sibling: Forwarded unchanged.
        uniform_samples: Forwarded unchanged.
        target_probs: Forwarded unchanged.
        draft_probs: Forwarded unchanged.
        deterministic: Flag forwarded to the kernel; defaults to ``True``.

    Returns:
        None. The wrapper discards whatever the op returns.
    """
    kernel = torch.ops.sgl_kernel.tree_speculative_sampling_target_only
    kernel(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        # The stream handle is the op's trailing positional argument.
        get_cuda_stream(),
    )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Invoke the ``build_tree_kernel_efficient`` sgl_kernel op.

    All eleven arguments are forwarded positionally in declaration order.
    Unlike ``tree_speculative_sampling_target_only``, no stream argument is
    appended here.

    Args:
        parent_list: Forwarded unchanged.
        selected_index: Forwarded unchanged.
        verified_seq_len: Forwarded unchanged.
        tree_mask: Forwarded unchanged.
        positions: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        retrive_next_token: Forwarded unchanged.
        retrive_next_sibling: Forwarded unchanged.
        topk: Integer forwarded unchanged.
        depth: Integer forwarded unchanged.
        draft_token_num: Integer forwarded unchanged.

    Returns:
        None. Any results are presumably written into the tensor arguments
        by the kernel -- which ones is not visible from this wrapper.
    """
    kernel = torch.ops.sgl_kernel.build_tree_kernel_efficient
    kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )


def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Invoke the ``build_tree_kernel`` sgl_kernel op.

    Takes the same parameters as ``build_tree_kernel_efficient`` minus
    ``retrive_next_token`` and ``retrive_next_sibling``; all nine arguments
    are forwarded positionally in declaration order.

    Args:
        parent_list: Forwarded unchanged.
        selected_index: Forwarded unchanged.
        verified_seq_len: Forwarded unchanged.
        tree_mask: Forwarded unchanged.
        positions: Forwarded unchanged.
        retrive_index: Forwarded unchanged.
        topk: Integer forwarded unchanged.
        depth: Integer forwarded unchanged.
        draft_token_num: Integer forwarded unchanged.

    Returns:
        None. Any results are presumably written into the tensor arguments
        by the kernel -- which ones is not visible from this wrapper.
    """
    kernel = torch.ops.sgl_kernel.build_tree_kernel
    kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )