Commit 19bc93d9 authored by 王敏
Browse files

增加medusa并行解码功能,后续增加使用说明和测试文档

parent aba40fda
...@@ -13,6 +13,7 @@ import numpy as np ...@@ -13,6 +13,7 @@ import numpy as np
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
...@@ -197,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -197,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_requests.clear() # type: ignore self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore self.prompt_adapter_prompt_mapping.clear() # type: ignore
self.tree_attn_masks[0] = None # type: ignore
def __init__( def __init__(
self, self,
...@@ -244,6 +246,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -244,6 +246,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit: bool = False, reinit: bool = False,
reinit_use_defaults: bool = False, reinit_use_defaults: bool = False,
encoder_seq_len: int = 0, encoder_seq_len: int = 0,
# attention mask used in tree-style generation
tree_attn_masks: Optional[List[torch.Tensor]] = None,
): ):
if reinit: if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore assert len(self.seq_ids) == len(seq_ids) # type: ignore
...@@ -335,6 +340,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -335,6 +340,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.prompt_adapter_prompt_mapping.clear() self.prompt_adapter_prompt_mapping.clear()
if tree_attn_masks:
self.tree_attn_masks = tree_attn_masks
else:
self.tree_attn_masks.clear()
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.input_positions = input_positions or [] self.input_positions = input_positions or []
...@@ -354,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -354,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping or []) prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = ( self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or []) prompt_adapter_prompt_mapping or [])
self.tree_attn_masks = tree_attn_masks or []
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_modal_inputs = multi_modal_inputs
...@@ -369,6 +380,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -369,6 +380,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.tree_attn_masks = [None for _ in range(self.n_seqs)]
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
...@@ -502,6 +514,16 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -502,6 +514,16 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
inter_data.input_positions[seq_idx].extend( inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len)) range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is None:
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
else:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.cpu().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
...@@ -835,10 +857,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -835,10 +857,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths. # Sequence and query lengths.
if cuda_graph_pad_size: if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks
max_context_len = 0
for inter_data in self.inter_data_list:
max_context_len = max(max_context_len, max(inter_data.context_lens))
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if len(tree_attention_masks_list) > 0:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size) seq_lens, query_lens, cuda_graph_pad_size, batch_size,
tree_attention_masks_tensor=tree_attention_masks_tensor)
# LoRA data. # LoRA data.
lora_requests = set() lora_requests = set()
......
...@@ -10,6 +10,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -10,6 +10,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerInput)
...@@ -92,6 +93,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -92,6 +93,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
@property @property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None return None
# NeuronWorker's `cache_engines` property: it unconditionally returns None,
# so callers receive no CacheEngine list from this worker (its `kv_cache`
# property directly above also returns None).
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
......
...@@ -15,6 +15,7 @@ from vllm.model_executor import set_random_seed ...@@ -15,6 +15,7 @@ from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerInput)
...@@ -212,6 +213,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -212,6 +213,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism. # parallelism.
return [self.tpu_cache] return [self.tpu_cache]
# TPUWorker's `cache_engines` property: it unconditionally returns None,
# so callers receive no CacheEngine list from this worker.
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
def prepare_worker_input( def prepare_worker_input(
self, self,
......
...@@ -291,6 +291,10 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -291,6 +291,10 @@ class Worker(LocalOrDistributedWorkerBase):
@property @property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache return self.gpu_cache
# Worker's `cache_engines` property: read-only accessor that returns whatever
# is stored in `self.cache_engine`.
# NOTE(review): the annotation promises Optional[List[CacheEngine]]; the
# assignment of `self.cache_engine` is not visible here - confirm it holds a
# list (or None) rather than a single CacheEngine.
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return self.cache_engine
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
......
...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform ...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables) update_environment_variables)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
...@@ -29,6 +30,8 @@ class WorkerBase(ABC): ...@@ -29,6 +30,8 @@ class WorkerBase(ABC):
communicate request metadata to other workers. communicate request metadata to other workers.
""" """
model_input: Optional[ModelRunnerInputBase] = None
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
...@@ -99,6 +102,23 @@ class WorkerBase(ABC): ...@@ -99,6 +102,23 @@ class WorkerBase(ABC):
@abstractmethod @abstractmethod
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
# Abstract read-only property on WorkerBase: the base body only raises
# NotImplementedError, so concrete workers must override it.
# NOTE(review): the docstring's last sentence ("inherit from WorkerBase
# instead") now sits on WorkerBase itself - confirm the wording still applies.
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
# Abstract read-only property on WorkerBase: subclasses override it to return
# either a list of CacheEngine objects or None. The base body only raises
# NotImplementedError.
# NOTE(review): presumably one CacheEngine per virtual engine, mirroring
# `kv_cache` - confirm against the concrete workers.
@property
@abstractmethod
def cache_engines(self) -> Optional[List[CacheEngine]]:
raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase): class LoraNotSupportedWorkerBase(WorkerBase):
...@@ -118,6 +138,14 @@ class LoraNotSupportedWorkerBase(WorkerBase): ...@@ -118,6 +138,14 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA") raise ValueError(f"{type(self)} does not support LoRA")
# Default `kv_cache` for LoraNotSupportedWorkerBase: returns None, which
# gives the abstract WorkerBase property a concrete implementation here.
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
# Default `cache_engines` for LoraNotSupportedWorkerBase: returns None, which
# gives the abstract WorkerBase property a concrete implementation here.
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
...@@ -249,6 +277,15 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -249,6 +277,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
# Attach each sequence group's tree attention mask and tree position ids to its entry in seq_group_metadata_list.
if execute_model_req.tree_attn_masks is not None:
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list):
# seq_group_metadata.tree_attn_masks = execute_model_req.tree_attn_masks[i]
# seq_group_metadata.tree_position_ids = execute_model_req.tree_position_ids[i]
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
...@@ -307,6 +344,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -307,6 +344,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input, worker_input, kwargs = inputs model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps num_steps = worker_input.num_steps
self.model_input = model_input
self.execute_worker(worker_input) self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment