Unverified Commit 973617ae authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)


Co-authored-by: default avatarCade Daniel <edacih@gmail.com>
Co-authored-by: default avatarCade Daniel <cade@anyscale.com>
parent 30e75439
...@@ -42,6 +42,7 @@ steps: ...@@ -42,6 +42,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- label: Distributed Tests (Multiple Groups) - label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
......
...@@ -18,6 +18,8 @@ def main(args: argparse.Namespace): ...@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch, # NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model, llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer, tokenizer=args.tokenizer,
quantization=args.quantization, quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size, tensor_parallel_size=args.tensor_parallel_size,
...@@ -28,6 +30,7 @@ def main(args: argparse.Namespace): ...@@ -28,6 +30,7 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path, quantization_param_path=args.quantization_param_path,
device=args.device, device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight, ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill, enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir, download_dir=args.download_dir,
block_size=args.block_size) block_size=args.block_size)
...@@ -99,6 +102,8 @@ if __name__ == '__main__': ...@@ -99,6 +102,8 @@ if __name__ == '__main__':
description='Benchmark the latency of processing a single batch of ' description='Benchmark the latency of processing a single batch of '
'requests till completion.') 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
...@@ -181,6 +186,7 @@ if __name__ == '__main__': ...@@ -181,6 +186,7 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='If True, the prefill requests can be chunked based on the ' help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument( parser.add_argument(
"--ray-workers-use-nsight", "--ray-workers-use-nsight",
action='store_true', action='store_true',
......
...@@ -5,56 +5,6 @@ from vllm import SamplingParams ...@@ -5,56 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
try:
with pytest.raises(
AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
......
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
...@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, ...@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size, batch_size,
max_output_len=output_len, max_output_len=output_len,
force_output_len=True) force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
...@@ -219,16 +219,16 @@ def broadcast_tensor_dict( ...@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes). dtypes).
""" """
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict
group = group or torch.distributed.group.WORLD group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group() metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group) ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})" assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
if rank == src: if rank == src:
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: List[Tuple[Any, Any]] = []
......
...@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase): ...@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
"""Initialize the worker and load the model. """Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
""" """
if self.speculative_config is None: assert self.parallel_config.world_size == 1, (
self._init_non_spec_worker() "GPUExecutor only supports single GPU.")
else:
self._init_spec_worker() self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs( def _get_worker_kwargs(
self, self,
...@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase): ...@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
...@@ -52,59 +52,22 @@ class GPUExecutor(ExecutorBase): ...@@ -52,59 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank: int = 0, local_rank: int = 0,
rank: int = 0, rank: int = 0,
distributed_init_method: Optional[str] = None): distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase( wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
) )
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method)) distributed_init_method))
return wrapper.worker return wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=self.speculative_config.
ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.speculative_config.
ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
......
...@@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG ...@@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor): class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.distributed_executor_backend == "ray" assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True, placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id, placement_group_bundle_index=bundle_id,
) )
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( )(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
...@@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper( self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker", worker_module_name=worker_module_name,
worker_class_name="Worker", worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
else: else:
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
...@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output, ...@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids, get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
parallel_config=speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
)
return spec_decode_worker
class SpecDecodeWorker(LoraNotSupportedWorkerBase): class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding. """Worker which implements speculative decoding.
...@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._configure_model_sampler_for_spec_decode() self._configure_model_sampler_for_spec_decode()
def load_model(self, *args, **kwargs):
pass
def _configure_model_sampler_for_spec_decode(self): def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode """Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing, to keep data on device without transferring to CPU and serializing,
...@@ -195,23 +231,91 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -195,23 +231,91 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks) num_cpu_blocks=num_cpu_blocks)
def _broadcast_control_flow_decision(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
disable_all_speculation: bool = False) -> Tuple[int, bool]:
"""Broadcast how many lookahead slots are scheduled for this step, and
whether all speculation is disabled, to all non-driver workers.
This is required as if the number of draft model runs changes
dynamically, the non-driver workers won't know unless we perform a
communication to inform then.
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
"""
if self.rank == self._driver_rank:
assert execute_model_req is not None
broadcast_dict = dict(
num_lookahead_slots=execute_model_req.num_lookahead_slots,
disable_all_speculation=disable_all_speculation,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
else:
assert execute_model_req is None
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
return (broadcast_dict["num_lookahead_slots"],
broadcast_dict["disable_all_speculation"])
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. """Perform speculative decoding on the input batch.
""" """
disable_all_speculation = False
if self.rank == self._driver_rank:
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
(num_lookahead_slots,
disable_all_speculation) = self._broadcast_control_flow_decision(
execute_model_req, disable_all_speculation)
if self.rank == self._driver_rank:
assert execute_model_req is not None
assert execute_model_req.seq_group_metadata_list is not None, ( assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding " "speculative decoding requires non-None seq_group_metadata_list"
"requires non-None seq_group_metadata_list") )
self._maybe_disable_speculative_tokens(
disable_all_speculation,
execute_model_req.seq_group_metadata_list)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
else:
self._run_non_driver_rank(num_lookahead_slots)
return []
def _should_disable_all_speculation(
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding # When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency. # to stop trading off throughput for latency.
disable_all = (execute_model_req.running_queue_size >= disable_all_speculation = (execute_model_req.running_queue_size >=
self.disable_by_batch_size) self.disable_by_batch_size)
if disable_all:
for seq_group_metadata in execute_model_req.seq_group_metadata_list: return disable_all_speculation
def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
if not disable_all_speculation:
return
for seq_group_metadata in seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode # Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever. # of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific # TODO(comaniac): We currently store spec decoding specific
...@@ -219,16 +323,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -219,16 +323,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker. # this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0 seq_group_metadata.num_speculative_tokens = 0
# If no spec tokens, call the proposer and scorer workers normally.
# This happens for prefill, or when the spec decode is disabled
# for this batch.
if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all)
return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest, def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]: skip_proposer: bool) -> List[SamplerOutput]:
...@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output.logprobs = None sampler_output.logprobs = None
return [sampler_output] return [sampler_output]
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
"""
# In non-driver workers the input is None
execute_model_req = None
# Even if num_lookahead_slots is zero, we want to run the proposer model
# as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we should
# delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model(execute_model_req)
self.scorer_worker.execute_model(execute_model_req)
@nvtx_range("spec_decode_worker._run_speculative_decoding_step") @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step( def _run_speculative_decoding_step(
self, self, execute_model_req: ExecuteModelRequest,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: num_lookahead_slots: int) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding. """Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each This invokes the proposer worker to get k speculative tokens for each
...@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Returns a list of SamplerOutput, each containing a single token per Returns a list of SamplerOutput, each containing a single token per
sequence. sequence.
""" """
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Generate proposals using draft worker. # Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req) proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
...@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def device(self): def device(self):
return self.scorer_worker.device return self.scorer_worker.device
@property
def _driver_rank(self) -> int:
return 0
def get_cache_block_size_bytes(self): def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes. """Return the size of a cache block in bytes.
......
...@@ -8,7 +8,7 @@ import torch.distributed ...@@ -8,7 +8,7 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
...@@ -43,6 +43,7 @@ class Worker(WorkerBase): ...@@ -43,6 +43,7 @@ class Worker(WorkerBase):
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment