Commit 89639c96 authored by jujl1's avatar jujl1
Browse files

feat: triton kernel 实现 update_input

parent 0936ee97
......@@ -205,7 +205,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new = len(generated_token_ids)
if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids
and request.num_computed_tokens <= request.num_prompt_tokens + num_new):
and request.num_computed_tokens >= request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
......
......@@ -23,6 +23,46 @@ from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs
import triton
import triton.language as tl
# Triton kernel: each program handles one (source row, destination slot) pair.
# It scans one row of `last_ids_ptr`, finds the right-most entry that is not
# -1, and writes that entry into `input_ids_ptr` at the requested position.
# If the row holds no entry other than -1, the value 0 is written instead.
@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,    # [B, T]
    input_ids_ptr,   # [N]
    update_req_ptr,  # [U]
    input_pos_ptr,   # [U]
    stride0,
    stride1,
    T,
    BLOCK_T: tl.constexpr,
):
    # Program id selects which (row, destination) pair this instance handles.
    slot = tl.program_id(0)
    src_row = tl.load(update_req_ptr + slot)
    dst_pos = tl.load(input_pos_ptr + slot)

    # Base address of the selected row; reused for the scan and the final read.
    row_base = last_ids_ptr + src_row * stride0

    # Read the row; lanes at or beyond T are filled with -1 so they are
    # treated exactly like invalid entries.
    cols = tl.arange(0, BLOCK_T)
    in_row = cols < T
    tokens = tl.load(row_base + cols * stride1, mask=in_row, other=-1)

    # Index reduction: keep the column number where the entry is valid and
    # -1 elsewhere, then take the maximum to get the last valid column.
    valid_cols = tl.where(tokens != -1, cols, -1)
    last_col = tl.max(valid_cols, axis=0)

    # Fetch the entry at that column; the load is masked off (yielding 0)
    # when no valid column exists, so no negative offset is dereferenced.
    token = tl.load(
        row_base + last_col * stride1,
        mask=last_col >= 0,
        other=0,
    )

    # Scatter the selected entry to its destination position.
    tl.store(input_ids_ptr + dst_pos, token)
class V1ZeroModelRunner(GPUModelRunner):
def __init__(self, vllm_config, device):
......@@ -302,18 +342,30 @@ class V1ZeroModelRunner(GPUModelRunner):
True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
def find_last_valid_vectorized(tensor):
    """Return, for every row of a 2-D tensor, its last element that is not -1.

    Args:
        tensor: 2-D tensor in which -1 marks an invalid entry.

    Returns:
        A new 1-D tensor with one element per row and the same dtype and
        device as ``tensor``. Rows that contain no valid entry (including
        rows of width zero) yield -1. The input is not modified.
    """
    num_rows, num_cols = tensor.size(0), tensor.size(1)
    if num_cols == 0:
        # A reduction over an empty dimension would raise; every row is
        # trivially "all invalid".
        return tensor.new_full((num_rows,), -1)

    valid = tensor != -1
    # Build the column indices on the tensor's own device so no implicit
    # host-to-device transfer is needed during indexing.
    cols = torch.arange(num_cols, device=tensor.device)
    # Index reduction: valid positions keep their column number, invalid
    # ones become -1, so the row-wise maximum is the last valid column
    # (or -1 when the row has none).
    last_col = torch.where(valid, cols, torch.full_like(cols, -1)).max(dim=1).values

    # Clamp so the gather index is always in range; rows without a valid
    # entry are overwritten with -1 right after.
    picked = tensor.gather(1, last_col.clamp(min=0).unsqueeze(1)).squeeze(1)
    return torch.where(last_col >= 0, picked, torch.full_like(picked, -1))
def fused_update_input_ids(
    last_sampled_token_ids,
    input_ids,
    update_req_indices,
    input_ids_indices,
):
    """Scatter the last valid token of selected rows into ``input_ids`` in place.

    Launches ``fused_last_valid_scatter_kernel`` with one program per element
    of ``update_req_indices``. Program ``i`` reads row ``update_req_indices[i]``
    of ``last_sampled_token_ids``, picks its last entry that is not -1, and
    stores it at ``input_ids[input_ids_indices[i]]``.

    Args:
        last_sampled_token_ids: 2-D tensor; -1 marks an invalid entry.
        input_ids: destination tensor, written in place.
        update_req_indices: row indices into ``last_sampled_token_ids``.
        input_ids_indices: destination positions in ``input_ids``; expected
            to have one entry per element of ``update_req_indices``.

    Raises:
        ValueError: if a row is wider than 1024 elements, the largest row
            width this launcher supports.
    """
    # Unpacking also enforces that the source tensor is 2-D.
    _, T = last_sampled_token_ids.shape

    max_block_t = 1024
    if T > max_block_t:
        raise ValueError(
            f"row width {T} exceeds the supported maximum of {max_block_t}")

    num_updates = update_req_indices.numel()
    if num_updates == 0:
        # Nothing to scatter; skip the kernel launch entirely.
        return

    # tl.arange needs a power-of-two extent; size the tile to the row
    # instead of always reducing over the maximum width.
    block_t = max(1, triton.next_power_of_2(T))

    fused_last_valid_scatter_kernel[(num_updates,)](
        last_sampled_token_ids,
        input_ids,
        update_req_indices,
        input_ids_indices,
        last_sampled_token_ids.stride(0),
        last_sampled_token_ids.stride(1),
        T,
        BLOCK_T=block_t,
    )
update_req_indices = []
input_ids_indices = []
......@@ -332,8 +384,12 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device,
True)
last_sampled_token_ids = find_last_valid_vectorized(self.last_sampled_token_ids).flatten()
input_ids[input_ids_indices_tensor] = last_sampled_token_ids[update_req_indices_tensor]
fused_update_input_ids(
self.last_sampled_token_ids,
input_ids,
update_req_indices_tensor,
input_ids_indices_tensor)
def propose_draft_token_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment