Unverified Commit 4cb5a523 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Tiny `skip_sample` adjust (#11225)

parent 85c1f793
...@@ -663,7 +663,11 @@ class Req: ...@@ -663,7 +663,11 @@ class Req:
@property
def is_prefill_only(self) -> bool:
    """Whether this request needs prefill only, i.e. no tokens will be generated.

    Returns True only when the caller asked for zero new tokens AND no
    speculative-decoding algorithm is configured.
    """
    # NOTE: when spec is enabled, prefill_only optimizations are disabled
    wants_no_new_tokens = self.sampling_params.max_new_tokens == 0
    spec_is_off = global_server_args_dict["speculative_algorithm"] is None
    return wants_no_new_tokens and spec_is_off
def add_latency(self, stage: RequestStage): def add_latency(self, stage: RequestStage):
if self.metrics_collector is None: if self.metrics_collector is None:
......
...@@ -237,7 +237,7 @@ class TpModelWorker: ...@@ -237,7 +237,7 @@ class TpModelWorker:
self, self,
model_worker_batch: ModelWorkerBatch, model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None, launch_done: Optional[threading.Event] = None,
skip_sample: bool = False, is_verify: bool = False,
) -> ForwardBatchOutput: ) -> ForwardBatchOutput:
# update the consumer index of hicache to the running batch # update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
...@@ -259,19 +259,16 @@ class TpModelWorker: ...@@ -259,19 +259,16 @@ class TpModelWorker:
if launch_done is not None: if launch_done is not None:
launch_done.set() launch_done.set()
if skip_sample: skip_sample = is_verify or model_worker_batch.is_prefill_only
next_token_ids = None next_token_ids = None
# For prefill-only requests, we still need to compute logprobs even when sampling is skipped
if ( if not skip_sample:
model_worker_batch.is_prefill_only next_token_ids = self.model_runner.sample(logits_output, forward_batch)
and model_worker_batch.return_logprob elif model_worker_batch.return_logprob and not is_verify:
): # NOTE: Compute logprobs without full sampling
# Compute logprobs without full sampling
self.model_runner.compute_logprobs_only( self.model_runner.compute_logprobs_only(
logits_output, model_worker_batch logits_output, model_worker_batch
) )
else:
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
return ForwardBatchOutput( return ForwardBatchOutput(
logits_output=logits_output, logits_output=logits_output,
......
...@@ -164,8 +164,6 @@ class TpModelWorkerClient: ...@@ -164,8 +164,6 @@ class TpModelWorkerClient:
forward_batch_output = self.worker.forward_batch_generation( forward_batch_output = self.worker.forward_batch_generation(
model_worker_batch, model_worker_batch,
model_worker_batch.launch_done, model_worker_batch.launch_done,
# Skip sampling for prefill-only requests
skip_sample=model_worker_batch.is_prefill_only,
) )
logits_output, next_token_ids, can_run_cuda_graph = ( logits_output, next_token_ids, can_run_cuda_graph = (
......
...@@ -823,7 +823,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -823,7 +823,7 @@ class EAGLEWorker(TpModelWorker):
# Forward # Forward
forward_batch_output = self.target_worker.forward_batch_generation( forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True model_worker_batch, is_verify=True
) )
logits_output, can_run_cuda_graph = ( logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output, forward_batch_output.logits_output,
......
...@@ -214,7 +214,7 @@ class NGRAMWorker: ...@@ -214,7 +214,7 @@ class NGRAMWorker:
if model_worker_batch.forward_mode.is_target_verify(): if model_worker_batch.forward_mode.is_target_verify():
forward_batch_output = self.target_worker.forward_batch_generation( forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True model_worker_batch, is_verify=True
) )
logits_output, can_run_cuda_graph = ( logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output, forward_batch_output.logits_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment