Unverified Commit c645e9a2 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Remove propose_draft method (#35070)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 944ffb59
...@@ -858,29 +858,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -858,29 +858,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
) )
@torch.inference_mode()
def propose_draft(
self,
input_batch: InputBatch,
last_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> torch.Tensor:
assert self.speculator is not None
draft_tokens = self.speculator.propose(
input_batch,
last_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
self.req_states.last_sampled_tokens,
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
)
return draft_tokens
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -1113,12 +1090,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1113,12 +1090,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
) )
if self.speculator is not None: if self.speculator is not None:
draft_tokens = self.propose_draft( draft_tokens = self.speculator.propose(
input_batch, input_batch,
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
self.req_states.last_sampled_tokens,
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
) )
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment