from enum import Enum
import os

import torch
import vllm.envs as envs
import triton
import triton.language as tl

# Read once at import time: True only when VLLM_ZERO_NO_THREAD is exactly '1'.
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'


def is_zero_no_thread():
    """Return a truthy value when no-thread mode and zero overhead are both on.

    The result is ``zero_no_thread and envs.VLLM_ZERO_OVERHEAD``, so it is
    ``False`` when the environment flag is unset and otherwise whatever
    ``envs.VLLM_ZERO_OVERHEAD`` evaluates to.
    """
    return zero_no_thread and envs.VLLM_ZERO_OVERHEAD


class SpecStepKind(Enum):
    """Step kinds stored in the module-level spec context."""

    KIND_DEFAULT = 0
    PREFILL = 1
    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3
    SCORE_DECODE = 4


class ZeroOverheadSpecContext:
    """Mutable holder for the state shared through the helpers below."""

    def __init__(self):
        # Current step kind and the one that was current before it.
        self.step_kind = SpecStepKind.KIND_DEFAULT
        self.last_step = SpecStepKind.KIND_DEFAULT
        # Filled in by the record_* helpers; None until first recorded.
        self.proposal_lens_list = None
        self.proposal_token_ids = None
        self.accepted_token_ids = None
        self.accepted_seq_ids = None


# Single module-wide context instance. The helpers below only mutate its
# attributes (the name is never rebound), so no ``global`` statement is needed.
spec_context = ZeroOverheadSpecContext()


def set_spec_step(_step):
    """Make ``_step`` the current step kind, remembering the previous one."""
    spec_context.last_step = spec_context.step_kind
    spec_context.step_kind = _step


def get_spec_step():
    """Return the current step kind."""
    return spec_context.step_kind


def get_spec_last_step():
    """Return the step kind that was current before the latest set_spec_step."""
    return spec_context.last_step


# NOTE: the parameter name shadows the builtin ``list``; it is kept unchanged
# so that callers passing it by keyword keep working.
def record_proposal_lens_list(list):
    """Store the proposal lengths list in the shared context."""
    spec_context.proposal_lens_list = list


def get_proposal_lens_list():
    """Return the stored proposal lengths list (None if never recorded)."""
    return spec_context.proposal_lens_list


def record_proposal_token_ids(tensor):
    """Store the proposal token ids in the shared context."""
    spec_context.proposal_token_ids = tensor


def get_proposal_token_ids():
    """Return the stored proposal token ids (None if never recorded)."""
    return spec_context.proposal_token_ids


def record_accepted_token_ids(tensor, seq_ids):
    """Store the accepted token ids together with their sequence ids."""
    spec_context.accepted_token_ids = tensor
    spec_context.accepted_seq_ids = seq_ids


def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as last recorded."""
    return spec_context.accepted_token_ids, spec_context.accepted_seq_ids


# Zero-overhead scheduling does not run inference on the default stream, in
# order to avoid the stream-synchronization problem on memory allocation that
# the runtime introduces.
alloc_stream = {}


def zero_overhead_stream(target_device):
    """Return the cached CUDA stream for ``target_device``.

    A ``torch.cuda.Stream`` is created the first time a device is seen and
    the same object is returned on every later call for that device.
    """
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]


@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,        # [B, T]
    input_ids_ptr,       # [N]
    update_req_ptr,      # [U]
    input_pos_ptr,       # [U]
    stride0,
    stride1,
    T,
    BLOCK_T: tl.constexpr,
):
    """Scatter the last non -1 value of one row into ``input_ids``.

    Program ``pid`` reads a row index from ``update_req_ptr[pid]`` and a
    destination offset from ``input_pos_ptr[pid]``, finds the highest column
    of that row (among the first ``T``) whose value is not -1, and stores
    that value at ``input_ids_ptr + input_pos``. When every inspected entry
    is -1 the value 0 is stored instead.
    """
    pid = tl.program_id(0)

    # indices
    req_idx = tl.load(update_req_ptr + pid)
    input_pos = tl.load(input_pos_ptr + pid)

    # load row; lanes at or beyond T read as -1 so they never count as valid
    offs = tl.arange(0, BLOCK_T)
    mask = offs < T
    row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
    vals = tl.load(row_ptr, mask=mask, other=-1)

    # column index where the entry is valid, -1 elsewhere; the max is the
    # last valid column, or -1 when the row has none
    idx = tl.where(vals != -1, offs, -1)
    last_idx = tl.max(idx, axis=0)

    # load last token; the mask prevents a read at column -1
    last_val = tl.load(
        last_ids_ptr + req_idx * stride0 + last_idx * stride1,
        mask=last_idx >= 0,
        other=0,
    )

    # scatter
    tl.store(input_ids_ptr + input_pos, last_val)


def fused_update_input_ids_impl(
    last_sampled_token_ids,
    input_ids,
    update_req_indices,
    input_ids_indices,
):
    """Launch ``fused_last_valid_scatter_kernel`` once per update entry.

    For every ``i`` in ``range(update_req_indices.numel())`` the kernel
    writes into ``input_ids[input_ids_indices[i]]`` the last entry of row
    ``update_req_indices[i]`` of ``last_sampled_token_ids`` that is not -1,
    or 0 when that row contains only -1. ``input_ids`` is modified in place
    and nothing is returned.

    Args:
        last_sampled_token_ids: 2-D tensor (its shape is unpacked as two
            dimensions); the second dimension must not exceed 1024.
        input_ids: destination tensor, indexed by ``input_ids_indices``.
        update_req_indices: row indices into ``last_sampled_token_ids``.
        input_ids_indices: destination offsets into ``input_ids``; read
            with the same program id as ``update_req_indices``.

    Raises:
        ValueError: if the second dimension of ``last_sampled_token_ids``
            is larger than the kernel block size.
    """
    # Unpacking two values also enforces that the tensor is 2-D.
    _, T = last_sampled_token_ids.shape
    U = update_req_indices.numel()

    BLOCK_T = 1024
    # Explicit check instead of ``assert``: under ``python -O`` an assert is
    # stripped and columns past BLOCK_T would be silently ignored.
    if T > BLOCK_T:
        raise ValueError(
            f"last_sampled_token_ids has {T} columns, "
            f"but at most {BLOCK_T} are supported"
        )

    # Nothing to scatter; avoid launching an empty grid.
    if U == 0:
        return

    grid = (U,)

    fused_last_valid_scatter_kernel[grid](
        last_sampled_token_ids,
        input_ids,
        update_req_indices,
        input_ids_indices,
        last_sampled_token_ids.stride(0),
        last_sampled_token_ids.stride(1),
        T,
        BLOCK_T=BLOCK_T,
    )