Commit 19bc93d9 authored by 王敏's avatar 王敏
Browse files

增加medusa并行解码功能,后续增加使用说明和测试文档

parent aba40fda
......@@ -13,6 +13,7 @@ import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
......@@ -197,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore
self.tree_attn_masks[0] = None # type: ignore
def __init__(
self,
......@@ -244,6 +246,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit: bool = False,
reinit_use_defaults: bool = False,
encoder_seq_len: int = 0,
# attention mask used in tree-style generation
tree_attn_masks: Optional[List[torch.Tensor]] = None,
):
if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore
......@@ -335,6 +340,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.prompt_adapter_prompt_mapping.clear()
if tree_attn_masks:
self.tree_attn_masks = tree_attn_masks
else:
self.tree_attn_masks.clear()
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
......@@ -354,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.tree_attn_masks = tree_attn_masks or []
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
......@@ -369,6 +380,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.tree_attn_masks = [None for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
......@@ -502,6 +514,16 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is None:
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
else:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.cpu().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
......@@ -835,10 +857,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths.
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks
max_context_len = 0
for inter_data in self.inter_data_list:
max_context_len = max(max_context_len, max(inter_data.context_lens))
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if len(tree_attention_masks_list) > 0:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
seq_lens, query_lens, cuda_graph_pad_size, batch_size,
tree_attention_masks_tensor=tree_attention_masks_tensor)
# LoRA data.
lora_requests = set()
......
......@@ -10,6 +10,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
......@@ -92,6 +93,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@torch.inference_mode()
def prepare_worker_input(
......
......@@ -15,6 +15,7 @@ from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
......@@ -212,6 +213,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
def prepare_worker_input(
self,
......
......@@ -291,6 +291,10 @@ class Worker(LocalOrDistributedWorkerBase):
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return self.cache_engine
@torch.inference_mode()
def prepare_worker_input(
......
......@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
......@@ -29,6 +30,8 @@ class WorkerBase(ABC):
communicate request metadata to other workers.
"""
model_input: Optional[ModelRunnerInputBase] = None
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
......@@ -99,6 +102,23 @@ class WorkerBase(ABC):
@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
@property
@abstractmethod
def cache_engines(self) -> Optional[List[CacheEngine]]:
raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
......@@ -118,6 +138,14 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@dataclasses.dataclass(frozen=True)
......@@ -249,6 +277,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
# set tree_attn_masks and position ids to seq_group_metadata_list
if execute_model_req.tree_attn_masks is not None:
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list):
# seq_group_metadata.tree_attn_masks = execute_model_req.tree_attn_masks[i]
# seq_group_metadata.tree_position_ids = execute_model_req.tree_position_ids[i]
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
......@@ -307,6 +344,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
self.model_input = model_input
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment