Commit 4a4e3601 authored by 王敏's avatar 王敏
Browse files

[feat]并行解码支持多卡推理

parent a1592b87
import os
import math
from typing import Iterable, List, Tuple
from typing import Iterable, List, Tuple, Optional
import torch
import torch.nn as nn
......@@ -10,9 +10,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import MLPSpeculatorConfig
from vllm import _custom_ops as ops
from vllm.distributed import tensor_model_parallel_all_gather, tensor_model_parallel_gather
SQRT2 = 2**0.5
......@@ -95,8 +97,16 @@ class MLPSpeculator(nn.Module):
# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
# proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
# proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
proj_first = ColumnParallelLinear(input_size=self.emb_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
proj_tied = ColumnParallelLinear(input_size=self.inner_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))
......@@ -116,9 +126,10 @@ class MLPSpeculator(nn.Module):
])
self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False)
ColumnParallelLinear(input_size=(self.emb_dim if i == 0 else self.inner_dim),
output_size=self.inner_dim,
bias=False,
gather_output=True)
for i in range(self.max_speculative_tokens)
])
......@@ -150,28 +161,19 @@ class MLPSpeculator(nn.Module):
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
head_index: int
) -> Tuple[Optional[SamplerOutput], Optional[torch.Tensor]]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# b x 1
last_tokens = input_ids.unsqueeze(1)
next_tokens = []
for head_index in range(num_predict_tokens):
# Project and predict
z = self.emb[head_index](last_tokens) # b k d
states = self.proj[head_index](previous_hidden_states)
z = self.emb[head_index](input_ids) # b k d
states, _ = self.proj[head_index](previous_hidden_states)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
......@@ -183,14 +185,19 @@ class MLPSpeculator(nn.Module):
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
# sampling_metadata is not None indicates that driver card is running
if sampling_metadata is not None:
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)
return next_tokens
return output, previous_hidden_states
else:
logits = self.head[head_index].linear_method.apply(self.head[head_index],
states,
bias=None)
logits = tensor_model_parallel_gather(logits)
return None, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
......@@ -201,7 +208,8 @@ class MLPSpeculator(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and os.environ['LM_NN'] == '1' and "head" in name:
if self.use_llama_nn:
if (os.environ['LM_NN'] == '1' and "head" in name) or "proj" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
......
......@@ -144,7 +144,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
medusa_buffers=self.medusa_buffers)
# create tree attn masks
if self.medusa_buffers is not None:
if self.is_driver_worker and self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens):
......
from typing import List, Optional, Set, Tuple
from typing import List, Optional, Set, Tuple, Dict
import torch
......@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.distributed import broadcast_tensor_dict
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
......@@ -15,6 +16,58 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
Not currently compatible with LoRA or chunked prefill.
"""
def _get_driver_input_and_broadcast(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
index: int,
last_tokens: Optional[torch.Tensor]=None,
previous_hidden_states: Optional[torch.Tensor]=None,
sampling_metadata: Optional[SamplingMetadata]=None
) -> Dict[str, torch.Tensor]:
if sampling_metadata is None and execute_model_req is not None:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
# b x 1
last_tokens = input_tokens.unsqueeze(1)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
tensor_dict = {
"input_tokens": last_tokens,
"previous_hidden_states": previous_hidden_states,
"sample_len": sample_len,
"head_index": index
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict, sampling_metadata
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode()
def sampler_output(
self,
......@@ -33,24 +86,42 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
"""
self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals(
input_ids=input_tokens,
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
num_predict_tokens=sample_len,
sampling_metadata=sampling_metadata)
model_outputs = []
last_tokens = None
previous_hidden_states = None
sampling_metadata = None
for index in range(sample_len):
if self.is_driver_worker:
tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req,
sample_len,
index,
last_tokens,
previous_hidden_states,
sampling_metadata)
assert sampling_metadata is not None
output, previous_hidden_states = self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=sampling_metadata,
head_index=index)
last_tokens = output.sampled_token_ids
model_outputs.append(output)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of mlp_speculator worker!!!")
self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=None,
head_index=tensor_dict["head_index"])
if self.is_driver_worker:
assert len(model_outputs) == sample_len
return model_outputs, True
......
......@@ -350,6 +350,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
......
......@@ -38,6 +38,7 @@ from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
logger = init_logger(__name__)
......@@ -134,6 +135,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
if ngram_prompt_lookup_max > 0:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
assert draft_parallel_config.tensor_parallel_size == 1
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
......@@ -608,7 +612,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model()
if not data["disable_all_speculation"]:
if not self.tree_style_spec_decoding:
# if not self.tree_style_spec_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
# # We run the proposer once per lookahead slot. In the future we
# # should delegate how many times it runs to the proposer.
# for _ in range(max(num_lookahead_slots, 1)):
# self.proposer_worker.execute_model()
# else:
# if not data["no_spec"]:
# self.proposer_worker.sampler_output(None, None, None)
if issubclass(type(self.proposer_worker), NonLLMProposerWorkerBase):
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, num_lookahead_slots, None)
else:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
......@@ -616,9 +635,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model()
else:
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, None, None)
if not data["no_spec"]:
self.scorer_worker.execute_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment