import torch
import numpy as np
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_tp_group
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile


class V1ZeroModelRunner:
    """State carried from one model-runner step to the next.

    The instance remembers the most recently sampled token-id tensor, the
    request ids / output lengths recorded while results were appended, a CUDA
    event plus host copy used to read those tokens back later, and the
    ``token_ids_cpu`` slices that were written during the last step.
    """

    def __init__(self):
        # Sampled token ids of the previous step (None until the first
        # non-spec-decode step has run).
        self.last_sampled_token_ids = None
        # Request ids appended by execute_model_sampled, and for each of them
        # the length of req_state.output_token_ids *before* extension.
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        # Recorded right after the device->host copy below is issued.
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        # Host copy of the previous step's sampled token ids.
        self.last_sampler_host_tokens = None
        # [req_idx, start_idx, end_idx] triples written into token_ids_cpu
        # during the last step.
        self.token_ids_cpu_fix_recode = []

    def set_last_sampled_token_ids(self, sampled_token_ids):
        """Store ``sampled_token_ids`` and start fresh bookkeeping lists.

        New list objects are bound on purpose (instead of ``.clear()``):
        execute_model_sampled keeps a reference to the previous
        ``last_sampled_req_ids`` list and hands it to the model output.
        """
        self.last_sampled_token_ids = sampled_token_ids
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []


v1_zero_overhead = V1ZeroModelRunner()


def zero_prepare_inputs(runner, scheduler_output, input_ids):
    """Patch ``input_ids`` in place with the previous step's sampled tokens.

    For every request of the current batch that also appears in
    ``v1_zero_overhead.last_sampled_req_ids``, the tokens of that request's
    row in ``last_sampled_token_ids`` are copied to the request's first
    position(s) in ``input_ids``. Does nothing when no previous sample exists
    or when no request matches.

    Args:
        runner: object exposing ``input_batch.req_ids`` and ``device``.
        scheduler_output: object exposing ``num_scheduled_tokens[req_id]``.
        input_ids: tensor that is modified in place.
    """
    last_sampled = v1_zero_overhead.last_sampled_token_ids
    if last_sampled is None:
        return

    sampled_tokens_num = last_sampled.shape[1]

    # One-time id -> position map instead of `in` + `list.index` per request.
    # setdefault keeps the first occurrence, which is what list.index returns.
    # NOTE(review): the position inside last_sampled_req_ids is used as the
    # row of last_sampled_token_ids; this presumes no request was skipped
    # when that list was filled -- confirm against execute_model_sampled.
    position_of = {}
    for position, prev_req_id in enumerate(
            v1_zero_overhead.last_sampled_req_ids):
        position_of.setdefault(prev_req_id, position)

    update_req_indices = []
    input_ids_indices = []
    token_idx = 0
    for req_id in runner.input_batch.req_ids:
        position = position_of.get(req_id)
        if position is not None:
            # Offset of the request's row inside the flattened sample tensor.
            update_req_indices.append(position * sampled_tokens_num)
            input_ids_indices.append(token_idx)
        # Advance to the start of the next request's scheduled tokens.
        token_idx += scheduler_output.num_scheduled_tokens[req_id]

    if not update_req_indices:
        return

    update_req_indices_tensor = async_tensor_h2d(
        update_req_indices, torch.int32, runner.device, True)
    input_ids_indices_tensor = async_tensor_h2d(
        input_ids_indices, torch.int32, runner.device, True)
    last_sampled_flat = last_sampled.flatten()
    for i in range(sampled_tokens_num):
        input_ids[input_ids_indices_tensor + i] = \
            last_sampled_flat[update_req_indices_tensor + i]


def execute_model_sampled(runner, max_gen_len, sampled_token_ids,
                          discard_sampled_tokens_req_indices,
                          scheduler_output, sampling_metadata, hidden_states,
                          sample_hidden_states, aux_hidden_states,
                          spec_decode_metadata, attn_metadata, logprobs_lists,
                          prompt_logprobs_dict, finished_sending,
                          finished_recving, num_nans_in_logits):
    """Book-keep the sampled tokens of one step and build the runner output.

    When ``max_gen_len == 1`` the sampled tokens are not read synchronously:
    a non-blocking copy to the host is issued, an event is recorded, and a
    list filled with 1s stands in for the real token ids. On the next such
    call the event is waited on and the real values from the previous step
    are written back into ``runner.input_batch.token_ids_cpu`` and into each
    request's ``output_token_ids``; they are also returned as
    ``fix_req_ids`` / ``fix_sampled_token_ids``.

    Otherwise the token ids come from ``runner.rejection_sampler.parse_output``.

    Returns:
        ZeroV1ModelRunnerOutput for this step.

    Raises:
        AssertionError: if a request would exceed ``runner.max_model_len``.
    """
    fix_req_ids = None
    fix_sampled_token_ids = None
    if max_gen_len == 1:
        # No spec decode tokens.
        # Identity check: the stored value is either None or a tensor, and a
        # tensor must not be compared to None with `!=`.
        if v1_zero_overhead.last_sampler_host_tokens is not None:
            # Wait until the previous step's device->host copy has finished.
            v1_zero_overhead.last_sampler_event.synchronize()
            fix_sampled_token_ids = \
                v1_zero_overhead.last_sampler_host_tokens.tolist()
            # Replace the values written last step with the real tokens.
            for req_idx, start_idx, end_idx in \
                    v1_zero_overhead.token_ids_cpu_fix_recode:
                runner.input_batch.token_ids_cpu[
                    req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
            # Keep a reference to the old list; set_last_sampled_token_ids
            # below binds a new one.
            fix_req_ids = v1_zero_overhead.last_sampled_req_ids
            # NOTE(review): req_idx here is the position in
            # last_sampled_req_ids and is also used as the row of
            # fix_sampled_token_ids -- confirm the two always line up.
            for req_idx, req_id in enumerate(fix_req_ids):
                if req_id in runner.requests:
                    req_state = runner.requests[req_id]
                    token_idx = v1_zero_overhead.last_sampled_token_lens[req_idx]
                    req_state.output_token_ids[token_idx] = \
                        fix_sampled_token_ids[req_idx][0]
        v1_zero_overhead.last_sampler_host_tokens = sampled_token_ids.to(
            'cpu', non_blocking=True)
        v1_zero_overhead.last_sampler_event.record()
        v1_zero_overhead.set_last_sampled_token_ids(sampled_token_ids)
        # Stand-in ids (all 1s) with the same shape as the sampled tensor.
        valid_sampled_token_ids = np.ones(sampled_token_ids.shape,
                                          dtype=int).tolist()
    else:
        # Includes spec decode tokens.
        valid_sampled_token_ids = runner.rejection_sampler.parse_output(
            sampled_token_ids,
            runner.input_batch.vocab_size,
        )

    # Mask out the sampled tokens that should not be sampled.
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()

    # Cache the sampled tokens in the model runner, so that the scheduler
    # doesn't need to send them back.
    # NOTE(woosuk): As an exception, when using PP, the scheduler sends
    # the sampled tokens back, because there's no direct communication
    # between the first-stage worker and the last-stage worker.
    v1_zero_overhead.token_ids_cpu_fix_recode.clear()
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if not sampled_ids:
            continue

        start_idx = runner.input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= runner.max_model_len, (
            "Sampled token IDs exceed the max model length. "
            f"Total number of tokens: {end_idx} > max_model_len: "
            f"{runner.max_model_len}")

        runner.input_batch.token_ids_cpu[req_idx,
                                         start_idx:end_idx] = sampled_ids
        # Remember the slice so the next step can overwrite it.
        v1_zero_overhead.token_ids_cpu_fix_recode.append(
            [req_idx, start_idx, end_idx])
        runner.input_batch.num_tokens_no_spec[req_idx] = end_idx
        runner.input_batch.num_tokens[req_idx] = end_idx
        req_id = runner.input_batch.req_ids[req_idx]
        if req_id in runner.requests:
            req_state = runner.requests[req_id]
            v1_zero_overhead.last_sampled_req_ids.append(req_id)
            # Index at which this step's first token lands in the output.
            v1_zero_overhead.last_sampled_token_lens.append(
                len(req_state.output_token_ids))
            req_state.output_token_ids.extend(sampled_ids)

    if not runner.speculative_config:
        # Speculative decoding is not enabled.
        spec_token_ids = None
    else:
        spec_token_ids = runner.propose_draft_token_ids(
            scheduler_output,
            valid_sampled_token_ids,
            sampling_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            spec_decode_metadata,
            attn_metadata,
        )

    # Clear KVConnector state after all KVs are generated.
    if has_kv_transfer_group():
        get_kv_transfer_group().clear_connector_metadata()

    runner.eplb_step()

    model_output = ZeroV1ModelRunnerOutput(
        req_ids=runner.input_batch.req_ids,
        req_id_to_index=runner.input_batch.req_id_to_index,
        sampled_token_ids=valid_sampled_token_ids,
        spec_token_ids=spec_token_ids,
        logprobs=logprobs_lists,
        prompt_logprobs_dict=prompt_logprobs_dict,
        pooler_output=[],
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        num_nans_in_logits=num_nans_in_logits,
        fix_req_ids=fix_req_ids,
        fix_sampled_token_ids=fix_sampled_token_ids,
    )
    return model_output