Unverified Commit 458611de authored by Liangsheng Yin, committed by GitHub

Unify forward output data structure (#11124)

parent 3511b370
......@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.environ import envs
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
......@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip, retry
from sglang.srt.utils import is_hip, retry
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
......@@ -237,7 +238,7 @@ class ModelConfig:
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
)
if (
get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
or is_in_ci() # FIXME: fix this special case
):
logger.warning(msg)
......
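The hunk above replaces the ad-hoc get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") lookup with the typed accessor envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get(). Below is a minimal sketch of what such a typed registry can look like; EnvBool and _Envs are hypothetical stand-ins for illustration, not the actual sglang.srt.environ implementation.

# Hypothetical sketch of a typed env-var registry; not the real sglang.srt.environ.
import os
from dataclasses import dataclass

@dataclass(frozen=True)
class EnvBool:
    name: str
    default: bool = False

    def get(self) -> bool:
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes", "on")

class _Envs:
    SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(
        "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"
    )

envs = _Envs()

if envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get():
    print("allowing context_length longer than the derived maximum")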
......@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
# Either success or failure
......@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
# send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
next_token_ids = result.next_token_ids
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
......@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result_disagg_prefill(
......@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
# carry the outputs to the next stage
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
......@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
# send out proxy tensors to the next stage
if self.cur_batch:
# FIXME(lsyin): remove this assert
assert result.pp_hidden_states_proxy_tensors.tensors is not None
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
result.pp_hidden_states_proxy_tensors.tensors,
all_gather_group=self.attn_tp_group,
)
......
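Both pipeline-parallel call sites in this mixin now pass result.pp_hidden_states_proxy_tensors.tensors (the raw name-to-tensor dict) to send_tensor_dict rather than the wrapper object itself. Judging only from the usage visible in this diff (dict-style indexing and a .tensors attribute), PPProxyTensors is a thin wrapper over that dict; the sketch below mirrors that shape and is not the real class.

# Simplified stand-in mirroring the PPProxyTensors usage seen in this diff;
# the real class lives in sglang.srt.model_executor.forward_batch_info.
from typing import Dict
import torch

class PPProxyTensorsSketch:
    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors  # name -> tensor, e.g. {"next_token_ids": ...}

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

proxy = PPProxyTensorsSketch({"next_token_ids": torch.tensor([11, 42])})
print(proxy["next_token_ids"])   # the scheduler indexes the wrapper directly
payload = proxy.tensors          # ...but send_tensor_dict now receives the raw dict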
......@@ -860,10 +860,6 @@ class Req:
)
# Batch id
bid = 0
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
"""Store all information of a batch on the scheduler."""
......@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
)
global bid
bid += 1
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
......@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
@dataclasses.dataclass
class ModelWorkerBatch:
# The batch id
bid: int
# The forward mode
forward_mode: ForwardMode
# The input ids
......
......@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatchOutput,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
......@@ -175,7 +179,6 @@ from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
get_zmq_socket,
is_cpu,
kill_itself_when_parent_died,
numa_bind_to_node,
point_to_point_pyobj,
......@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
_is_cpu = is_cpu()
@dataclass
class GenerationBatchResult:
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
next_token_ids: Optional[List[int]]
can_run_cuda_graph: bool
# For output processing
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
bid: int
can_run_cuda_graph: bool
@classmethod
def from_forward_batch_output(
cls,
forward_batch_output: ForwardBatchOutput,
extend_input_len_per_req: List[int],
extend_logprob_start_len_per_req: List[int],
):
# TODO(lsyin): remove this workaround logic and try to unify output classes
return cls(
logits_output=forward_batch_output.logits_output,
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
next_token_ids=forward_batch_output.next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
)
@classmethod
def from_pp_proxy(
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
):
# TODO(lsyin): also simplify this logic
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
proxy_dict = next_pp_outputs.tensors
return cls(
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
extend_logprob_start_len_per_req=proxy_dict.get(
"extend_logprob_start_len_per_req", None
),
can_run_cuda_graph=can_run_cuda_graph,
)
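A minimal usage sketch of the adapter above, assuming this commit is applied so that GenerationBatchResult and ForwardBatchOutput are importable from the modules shown in this diff; the tensors and per-request lists are placeholders.

# Placeholder values; import paths assume this commit is applied.
import torch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput

# What a TP worker returns on the last pipeline rank.
forward_batch_output = ForwardBatchOutput(
    logits_output=None,                       # LogitsProcessorOutput in the real code
    next_token_ids=torch.tensor([101, 202]),
    can_run_cuda_graph=True,
)

# The scheduler wraps it together with the per-request bookkeeping it copied
# before the overlap schedule could modify the batch.
ret = GenerationBatchResult.from_forward_batch_output(
    forward_batch_output=forward_batch_output,
    extend_input_len_per_req=[8, 13],
    extend_logprob_start_len_per_req=None,
)
assert ret.next_token_ids is forward_batch_output.next_token_ids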
@dataclass
class EmbeddingBatchResult:
embeddings: torch.Tensor
bid: int
class Scheduler(
......@@ -403,6 +441,12 @@ class Scheduler(
else:
self.draft_worker = None
# Dispatch the model worker
if self.spec_algorithm.is_none():
self.model_worker = self.tp_worker
else:
self.model_worker = self.draft_worker
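With this dispatch, the scheduler calls a single self.model_worker.forward_batch_generation(...) that returns a ForwardBatchOutput whether or not speculative decoding is enabled. The Protocol below is an illustrative reading of that implied interface, not a class defined in the codebase.

# Illustrative protocol implied by the dispatch above; not defined in sglang itself.
from typing import Protocol, Union

from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput

class ModelWorkerLike(Protocol):
    # TpModelWorker takes a ModelWorkerBatch; the spec workers (EAGLE, NGRAM)
    # take the ScheduleBatch directly, hence the union at the call site.
    def forward_batch_generation(
        self, batch: Union[ModelWorkerBatch, ScheduleBatch]
    ) -> ForwardBatchOutput: ...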
# Get token and memory info from the model worker
(
self.max_total_num_tokens,
......@@ -959,7 +1003,6 @@ class Scheduler(
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
while True:
server_is_idle = True
......@@ -980,10 +1023,7 @@ class Scheduler(
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
next_token_ids = result.next_token_ids
if self.cur_batch.return_logprob:
pp_outputs = PPProxyTensors(
{
......@@ -1031,17 +1071,10 @@ class Scheduler(
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult(
output_result = GenerationBatchResult.from_pp_proxy(
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=next_pp_outputs.tensors.get(
"extend_input_len_per_req", None
),
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
"extend_logprob_start_len_per_req", None
),
bid=bids[next_mb_id],
next_pp_outputs=next_pp_outputs,
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result(mbs[next_mb_id], output_result)
......@@ -1049,8 +1082,6 @@ class Scheduler(
# (not last rank)
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
......@@ -1072,8 +1103,10 @@ class Scheduler(
# send out proxy tensors to the next stage
if self.cur_batch:
# FIXME(lsyin): remove this assert
assert result.pp_hidden_states_proxy_tensors.tensors is not None
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
result.pp_hidden_states_proxy_tensors.tensors,
all_gather_group=self.attn_tp_group,
)
......@@ -2016,33 +2049,25 @@ class Scheduler(
# Run forward
if self.is_generation:
batch_or_worker_batch = batch
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
# FIXME(lsyin): remove this if and finally unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()
if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
bid = model_worker_batch.bid
else:
(
logits_output,
next_token_ids,
bid,
num_accepted_tokens,
can_run_cuda_graph,
) = self.draft_worker.forward_batch_speculative_generation(batch)
bs = batch.batch_size()
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens
if self.pp_group.is_last_rank:
batch.output_ids = next_token_ids
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.update_spec_metrics(
batch.batch_size(), forward_batch_output.num_accepted_tokens
)
# update batch's output ids
batch.output_ids = forward_batch_output.next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
......@@ -2051,6 +2076,7 @@ class Scheduler(
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None
if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
......@@ -2058,25 +2084,15 @@ class Scheduler(
else:
extend_logprob_start_len_per_req = None
ret = GenerationBatchResult(
logits_output=logits_output if self.pp_group.is_last_rank else None,
pp_hidden_states_proxy_tensors=(
pp_hidden_states_proxy_tensors
if not self.pp_group.is_last_rank
else None
),
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
return GenerationBatchResult.from_forward_batch_output(
forward_batch_output=forward_batch_output,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=bid,
can_run_cuda_graph=can_run_cuda_graph,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(
embeddings=embeddings, bid=model_worker_batch.bid
)
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret
def process_batch_result(
......
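The comment in run_batch about copying extend_input_len_per_req before the overlap schedule can modify the batch is why those two per-request lists are materialized eagerly instead of being read from batch.reqs during output processing. A toy illustration of the hazard, with illustrative names rather than scheduler internals:

# Toy illustration of why per-request values are snapshotted before the overlap
# schedule mutates the batch; the names here are illustrative only.
class Req:
    def __init__(self, extend_input_len: int):
        self.extend_input_len = extend_input_len

batch_reqs = [Req(8), Req(13)]

snapshot = [r.extend_input_len for r in batch_reqs]  # eager copy, as in run_batch

# ... the overlap schedule prepares the next batch and rewrites the requests ...
batch_reqs[0].extend_input_len = 1

late_read = [r.extend_input_len for r in batch_reqs]
print(snapshot)   # [8, 13] -> correct values for output processing
print(late_read)  # [1, 13] -> wrong if read only after the mutation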
......@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
kv_events_config, self.attn_dp_rank
)
def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens
def log_prefill_stats(
self: Scheduler,
adder: PrefillAdder,
......
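The new update_spec_metrics helper centralizes counter updates that previously sat inline in run_batch. A worked example of the arithmetic is below; reading the accumulated ratio as average tokens per request per forward step is my interpretation, not a metric name from the source.

# Worked example of the counter arithmetic; standalone mirror of the mixin above.
spec_num_total_accepted_tokens = 0
spec_num_total_forward_ct = 0
num_generated_tokens = 0

def update_spec_metrics(bs: int, num_accepted_tokens: int) -> None:
    global spec_num_total_accepted_tokens, spec_num_total_forward_ct, num_generated_tokens
    spec_num_total_accepted_tokens += num_accepted_tokens + bs  # accepted drafts + 1 per request
    spec_num_total_forward_ct += bs
    num_generated_tokens += num_accepted_tokens

# One verify step over a batch of 4 requests that accepted 6 draft tokens in total:
update_spec_metrics(bs=4, num_accepted_tokens=6)
print(spec_num_total_accepted_tokens / spec_num_total_forward_ct)  # 2.5 tokens per request per step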
......@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
self.set_next_batch_sampling_info_done(batch)
else: # embedding or reward model
embeddings, bid = result.embeddings, result.bid
embeddings = embeddings.tolist()
embeddings = result.embeddings.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
......
......@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardBatchOutput,
PPProxyTensors,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server_args import ServerArgs
......@@ -234,9 +238,7 @@ class TpModelWorker:
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
) -> Tuple[
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
]:
) -> ForwardBatchOutput:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
......@@ -271,13 +273,20 @@ class TpModelWorker:
else:
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
return logits_output, next_token_ids, can_run_cuda_graph
return ForwardBatchOutput(
logits_output=logits_output,
next_token_ids=next_token_ids,
can_run_cuda_graph=can_run_cuda_graph,
)
else:
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return pp_proxy_tensors.tensors, None, can_run_cuda_graph
return ForwardBatchOutput(
pp_proxy_tensors=pp_proxy_tensors,
can_run_cuda_graph=can_run_cuda_graph,
)
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......
......@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode
from sglang.utils import get_exception_traceback
......@@ -160,13 +161,17 @@ class TpModelWorkerClient:
self.future_map.resolve_future(model_worker_batch)
# Run forward
forward_batch_output = self.worker.forward_batch_generation(
model_worker_batch,
model_worker_batch.launch_done,
# Skip sampling for prefill-only requests
skip_sample=model_worker_batch.is_prefill_only,
)
logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation(
model_worker_batch,
model_worker_batch.launch_done,
# Skip sampling for prefill-only requests
skip_sample=model_worker_batch.is_prefill_only,
)
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
forward_batch_output.can_run_cuda_graph,
)
# Update the future token ids map
......@@ -227,7 +232,7 @@ class TpModelWorkerClient:
def forward_batch_generation(
self, model_worker_batch: ModelWorkerBatch
) -> Tuple[None, torch.Tensor, bool]:
) -> ForwardBatchOutput:
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
sampling_info = model_worker_batch.sampling_info
sampling_info.update_penalties()
......@@ -250,7 +255,10 @@ class TpModelWorkerClient:
future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
)
return None, future_next_token_ids, False
return ForwardBatchOutput(
next_token_ids=future_next_token_ids,
can_run_cuda_graph=False,
)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights_from_disk(recv_req)
......
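In the overlap client, forward_batch_generation returns immediately with placeholder next_token_ids obtained from FutureMap, and the background thread resolves them once sampling finishes; can_run_cuda_graph is reported as False, presumably because the real answer is not known at return time. The toy class below only illustrates the placeholder-then-resolve idea and is not the FutureMap in sglang.srt.managers.overlap_utils.

# Toy stand-in for the future-token pattern; the real FutureMap has a different interface.
import torch

class ToyFutureMap:
    def __init__(self, size: int):
        self.buf = torch.zeros(size, dtype=torch.long)
        self.ct = 0

    def update_next_future(self, bs: int) -> torch.Tensor:
        # Hand out negative indices as placeholders for tokens not yet sampled.
        ids = -torch.arange(self.ct + 1, self.ct + bs + 1)
        self.ct += bs
        return ids

    def resolve(self, placeholders: torch.Tensor, real_ids: torch.Tensor) -> None:
        # The background worker writes real token ids into the slots the
        # placeholders point at; consumers later look them up by -placeholder - 1.
        self.buf[-placeholders - 1] = real_ids

fmap = ToyFutureMap(size=16)
placeholders = fmap.update_next_future(bs=2)      # returned to the scheduler immediately
fmap.resolve(placeholders, torch.tensor([7, 9]))  # filled in after the GPU forward finishes
print(fmap.buf[:2])                               # tensor([7, 9])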
......@@ -2,11 +2,10 @@ from __future__ import annotations
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
if TYPE_CHECKING:
......
......@@ -900,6 +900,17 @@ class ForwardBatch:
return self.tbo_split_seq_index is not None
@dataclass
class ForwardBatchOutput:
# FIXME(lsyin): unify the forward batch output between different spec and parallelism
# need to be more organized
logits_output: Optional[torch.Tensor] = None
next_token_ids: Optional[torch.Tensor] = None
num_accepted_tokens: Optional[int] = None
pp_proxy_tensors: Optional[PPProxyTensors] = None
can_run_cuda_graph: bool = False
def enable_num_token_non_padded(server_args):
return get_moe_expert_parallel_world_size() > 1
......
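Different producers fill different subsets of the ForwardBatchOutput fields defined above, and the GenerationBatchResult adapters normalize them on the scheduler side. A small sketch of the three construction patterns visible in this diff; tensor shapes, values, and the proxy-tensor key are placeholders.

# Placeholder values; import paths reflect the module extended by this commit.
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, PPProxyTensors

# TP worker on the last PP rank: logits and sampled token ids are available.
tp_out = ForwardBatchOutput(
    logits_output=torch.randn(2, 32000),   # stand-in for the logits processor output
    next_token_ids=torch.tensor([5, 7]),
    can_run_cuda_graph=True,
)

# Non-last PP rank: only hidden-state proxy tensors are produced.
pp_out = ForwardBatchOutput(
    pp_proxy_tensors=PPProxyTensors({"hidden_states": torch.zeros(2, 4096)}),  # key is illustrative
    can_run_cuda_graph=False,
)

# Speculative workers (EAGLE, NGRAM) additionally report accepted-token counts.
spec_out = ForwardBatchOutput(
    logits_output=torch.randn(2, 32000),
    next_token_ids=torch.tensor([5, 7]),
    num_accepted_tokens=3,
    can_run_cuda_graph=True,
)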
......@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import (
ScheduleBatch,
get_last_loc,
......@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardBatchOutput,
ForwardMode,
)
from sglang.srt.server_args import ServerArgs
......@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
def draft_model_runner(self):
return self.model_runner
def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
"""Run speculative decoding forward.
NOTE: Many states of the batch are modified as you go through. It is not guaranteed that
......@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch)
logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
batch
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return logits_output, next_token_ids, bid, 0, False
return ForwardBatchOutput(
logits_output=logits_output,
next_token_ids=next_token_ids,
num_accepted_tokens=0,
can_run_cuda_graph=False,
)
else:
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
......@@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker):
# decode is not finished
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
return ForwardBatchOutput(
logits_output=logits_output,
next_token_ids=verify_output.verified_id,
num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph=can_run_cuda_graph,
)
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
......@@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker):
Returns:
logits_output: The output of logits. It will contain the full hidden states.
next_token_ids: Next token ids generated.
bid: The model batch ID. Used for overlap schedule.
"""
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
logits_output, next_token_ids = (
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
)
return (
logits_output,
next_token_ids,
model_worker_batch.bid,
model_worker_batch.seq_lens_cpu,
)
......@@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker):
).cpu()
# Forward
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.can_run_cuda_graph,
)
vocab_mask = None
......
......@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
......@@ -207,17 +207,18 @@ class NGRAMWorker:
batch_tokens.append(put_ids)
self.ngram_cache.batch_put(batch_tokens)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
self._prepare_for_speculative_decoding(batch)
model_worker_batch = batch.get_model_worker_batch()
bid = model_worker_batch.bid
num_accepted_tokens = 0
if model_worker_batch.forward_mode.is_target_verify():
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.can_run_cuda_graph,
)
verify_input = model_worker_batch.spec_info
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
......@@ -227,14 +228,18 @@ class NGRAMWorker:
batch.forward_mode = ForwardMode.DECODE
else:
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
logits_output, next_token_ids, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(model_worker_batch)
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
forward_batch_output.can_run_cuda_graph,
)
return (
logits_output,
next_token_ids,
bid,
num_accepted_tokens,
can_run_cuda_graph,
return ForwardBatchOutput(
logits_output=logits_output,
next_token_ids=next_token_ids,
num_accepted_tokens=num_accepted_tokens,
can_run_cuda_graph=can_run_cuda_graph,
)