Unverified Commit e95cd879 authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)

parent 69e1d2fb
...@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase): ...@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), ( and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.") "Cache operations are not supported for Neuron backend.")
assert num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list) seq_group_metadata_list=seq_group_metadata_list)
......
...@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase): ...@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int = 0) -> SamplerOutput:
all_outputs = self._run_workers( all_outputs = self._run_workers(
"execute_model", "execute_model",
driver_kwargs={ driver_kwargs={
......
...@@ -693,3 +693,16 @@ class SamplerOutput: ...@@ -693,3 +693,16 @@ class SamplerOutput:
def __eq__(self, other: object): def __eq__(self, other: object):
return isinstance(other, return isinstance(other,
self.__class__) and self.outputs == other.outputs self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
...@@ -6,10 +6,10 @@ import torch ...@@ -6,10 +6,10 @@ import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
sampler_output_to_torch, nvtx_range, sampler_output_to_torch,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker from vllm.worker.worker_base import WorkerBase
SeqId = int SeqId = int
TargetSeqId = int TargetSeqId = int
...@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree. of topk/tree.
""" """
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker self._scorer_worker = scorer_worker
self._device = device self._device = device
self._vocab_size = vocab_size self._vocab_size = vocab_size
...@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
return_python_output=False) )
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch( all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list), original_bs=len(seq_group_metadata_list),
...@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
""" """
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors(
sampler_output=target_sampler_output,
batch_size=len(non_spec_indices) + num_scoring_tokens,
vocab_size=self._vocab_size,
device=self._device,
)
(target_token_ids, target_probs, non_spec_target_token_ids, (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output( non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens) target_sampler_output, num_scoring_tokens)
......
...@@ -6,7 +6,8 @@ import torch ...@@ -6,7 +6,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.util import (maybe_mock_device_tensors,
sampler_output_to_torch)
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -69,6 +70,9 @@ class MultiStepWorker(Worker): ...@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output, self._append_new_tokens(model_output,
copied_seq_group_metadata_list) copied_seq_group_metadata_list)
...@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
sampler_output = maybe_sampler_output sampler_output = maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for step_output in sampler_output:
maybe_mock_device_tensors(
sampler_output=step_output,
batch_size=len(proposal_lens),
vocab_size=self._vocab_size,
device=self._device,
)
proposal_tokens, proposal_probs = sampler_output_to_torch( proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output) sampler_output)
......
...@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput) SequenceGroupOutput, SequenceOutput)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
...@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector ...@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
logger = init_logger(__name__)
class SpecDecodeWorker(LoraNotSupportedWorkerBase): class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
""" """
@classmethod
def from_workers(cls, proposer_worker: MultiStepWorker,
scorer_worker: WorkerBase) -> "SpecDecodeWorker":
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
# TODO(cade) disable strict mode for speedup.
rejection_sampler=RejectionSampler(strict_mode=True),
)
def __init__( def __init__(
self, self,
proposer_worker: MultiStepWorker, proposer_worker: MultiStepWorker,
scorer_worker: Worker, scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
): ):
...@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.init_device() self.scorer_worker.init_device()
self.proposer_worker.init_device() self.proposer_worker.init_device()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer( self.scorer = BatchExpansionTop1Scorer(
...@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]], blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]], blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]], blocks_to_copy: Optional[Dict[int, List[int]]],
num_spec_tokens: int, num_lookahead_slots: int,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. """Perform speculative decoding on the input batch.
""" """
...@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding " "speculative decoding "
"requires non-None seq_group_metadata_list") "requires non-None seq_group_metadata_list")
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
# If no spec tokens, call the proposer and scorer workers normally. # If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill. # Used for prefill.
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec( return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
...@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
k=num_spec_tokens, k=num_lookahead_slots,
) )
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
...@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the proposer and scorer model so that the KV cache is consistent between the
two. two.
""" """
logger.info("run proposer worker no spec")
self.proposer_worker.execute_model( self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
return_python_output=False) )
logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model( sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
# Clear device tensors from sampler output. This reduces communication # Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers. # overhead when the engine runs in a different process than the workers.
...@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence. sequence.
""" """
logger.info("get spec proposals")
# Generate proposals using draft worker. # Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals( proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k) blocks_to_copy, k)
logger.info("score proposals")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list, seq_group_metadata_list,
blocks_to_swap_in, blocks_to_swap_in,
...@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals, proposals,
) )
logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list, accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k) proposal_scores, proposals, k)
logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list, return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k) accepted_token_ids, k)
...@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
parent_seq_id=seq_id, parent_seq_id=seq_id,
output_token=token_id, output_token=token_id,
# TODO Add verifier logprobs. # TODO Add verifier logprobs.
logprobs={token_id: 0.0}, logprobs={token_id: Logprob(0.0)},
) )
], ],
prompt_logprobs=None, prompt_logprobs=None,
......
...@@ -82,6 +82,32 @@ def sampler_output_to_torch( ...@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return sampled_token_ids, sampled_token_probs return sampled_token_ids, sampled_token_probs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
vocab_size: int, device: str) -> None:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values = [
sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
]
assert all(v is None for v in values) or not any(v is None for v in values)
if not any(v is None for v in values):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output.sampled_token_probs = torch.nn.functional.softmax(
torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
dim=-1)
sampler_output.sampled_token_ids = torch.randint(low=10,
high=100,
size=(batch_size, ),
dtype=torch.long,
device=device)
@contextmanager @contextmanager
def nvtx_range(msg, *args, **kwargs): def nvtx_range(msg, *args, **kwargs):
""" """
......
...@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]: ) -> List[SamplerOutput]:
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
...@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return []
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache) self.cpu_cache)
return output
# CPU worker only supports single-step execution.
return [output]
def init_distributed_environment(self) -> None: def init_distributed_environment(self) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
......
"""A Neuron worker class.""" """A Neuron worker class."""
from typing import List, Optional, Tuple from typing import List, Tuple
import torch import torch
import torch.distributed import torch.distributed
...@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase): ...@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]: ) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return []
output = self.model_runner.execute_model(seq_group_metadata_list) output = self.model_runner.execute_model(seq_group_metadata_list)
return output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return [output]
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block. """Determine the size in bytes of a cache block.
......
...@@ -210,7 +210,9 @@ class Worker(WorkerBase): ...@@ -210,7 +210,9 @@ class Worker(WorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]: num_lookahead_slots: int = 0,
) -> List[SamplerOutput]:
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
...@@ -235,11 +237,14 @@ class Worker(WorkerBase): ...@@ -235,11 +237,14 @@ class Worker(WorkerBase):
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return []
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache) self.gpu_cache)
return output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
......
...@@ -40,12 +40,13 @@ class WorkerBase(ABC): ...@@ -40,12 +40,13 @@ class WorkerBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def execute_model(self, def execute_model(
seq_group_metadata_list: List[SequenceGroupMetadata], self, seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
blocks_to_swap_out: Dict[int, int], int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
"""Executes one model step on the given sequences.""" """Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment