Commit 1fe1a4b6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]并行解码支持多卡推理

See merge request dcutoolkit/deeplearing/vllm!48
parents a1592b87 4a4e3601
import os import os
import math import math
from typing import Iterable, List, Tuple from typing import Iterable, List, Tuple, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,9 +10,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -10,9 +10,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import MLPSpeculatorConfig from vllm.transformers_utils.configs import MLPSpeculatorConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import tensor_model_parallel_all_gather, tensor_model_parallel_gather
SQRT2 = 2**0.5 SQRT2 = 2**0.5
...@@ -95,8 +97,16 @@ class MLPSpeculator(nn.Module): ...@@ -95,8 +97,16 @@ class MLPSpeculator(nn.Module):
# the initial projection from the base model may # the initial projection from the base model may
# have a different size, so that stays separate. # have a different size, so that stays separate.
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False) # proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False) # proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
proj_first = ColumnParallelLinear(input_size=self.emb_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
proj_tied = ColumnParallelLinear(input_size=self.inner_dim,
output_size=self.inner_dim,
bias=False,
gather_output=True)
self.proj = nn.ModuleList([proj_first] + [proj_tied] * self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1)) (self.max_speculative_tokens - 1))
...@@ -116,9 +126,10 @@ class MLPSpeculator(nn.Module): ...@@ -116,9 +126,10 @@ class MLPSpeculator(nn.Module):
]) ])
self.proj = nn.ModuleList([ self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim), ColumnParallelLinear(input_size=(self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim, output_size=self.inner_dim,
bias=False) bias=False,
gather_output=True)
for i in range(self.max_speculative_tokens) for i in range(self.max_speculative_tokens)
]) ])
...@@ -150,47 +161,43 @@ class MLPSpeculator(nn.Module): ...@@ -150,47 +161,43 @@ class MLPSpeculator(nn.Module):
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
num_predict_tokens: int, num_predict_tokens: int,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]: head_index: int
) -> Tuple[Optional[SamplerOutput], Optional[torch.Tensor]]:
if num_predict_tokens > self.max_speculative_tokens: if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is " raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but " f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested") f"{num_predict_tokens} were requested")
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
if self.scale_input: if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# b x 1 # Project and predict
last_tokens = input_ids.unsqueeze(1) z = self.emb[head_index](input_ids) # b k d
states, _ = self.proj[head_index](previous_hidden_states)
next_tokens = []
for head_index in range(num_predict_tokens):
# Project and predict # Weighted add of state_weight*state and emb_weight*z
z = self.emb[head_index](last_tokens) # b k d # Let subsequent LN take care of denominator
states = self.proj[head_index](previous_hidden_states) # state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight)
# Weighted add of state_weight*state and emb_weight*z states = self.activation(self.ln[head_index](states)) # b k d
# Let subsequent LN take care of denominator previous_hidden_states = states
# state_weight is close to 1, so shouldn't be any precision issues # TODO: not yet supporting top_k_tokens_per_head
states.add_(z, alpha=self.emb_weight / self.state_weight) states = states.flatten(0, 1)
states = self.activation(self.ln[head_index](states)) # b k d
previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
# sampling_metadata is not None indicates that driver card is running
if sampling_metadata is not None:
logits = self.logits_processor(self.head[head_index], states, logits = self.logits_processor(self.head[head_index], states,
sampling_metadata) sampling_metadata)
output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)
return next_tokens output = self.sampler(logits, sampling_metadata)
return output, previous_hidden_states
else:
logits = self.head[head_index].linear_method.apply(self.head[head_index],
states,
bias=None)
logits = tensor_model_parallel_gather(logits)
return None, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -201,11 +208,12 @@ class MLPSpeculator(nn.Module): ...@@ -201,11 +208,12 @@ class MLPSpeculator(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and os.environ['LM_NN'] == '1' and "head" in name: if self.use_llama_nn:
_weight = torch.zeros_like(param.data) if (os.environ['LM_NN'] == '1' and "head" in name) or "proj" in name:
ori_shape =_weight.shape _weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, param.data, _weight.shape[0], _weight.shape[1])
param.data.copy_(_weight) ops.trans_w16_gemm(_weight, param.data, _weight.shape[0], _weight.shape[1])
param.data.copy_(_weight)
param.data=param.data.reshape(ori_shape[1],-1)
param.data=param.data.reshape(ori_shape[1],-1)
...@@ -144,7 +144,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -144,7 +144,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
medusa_buffers=self.medusa_buffers) medusa_buffers=self.medusa_buffers)
# create tree attn masks # create tree attn masks
if self.medusa_buffers is not None: if self.is_driver_worker and self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"] seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens) max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens): for sampler_output, seq_len in zip(model_outputs, seq_lens):
......
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple, Dict
import torch import torch
...@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.distributed import broadcast_tensor_dict
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
...@@ -15,6 +16,58 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): ...@@ -15,6 +16,58 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
Not currently compatible with LoRA or chunked prefill. Not currently compatible with LoRA or chunked prefill.
""" """
def _get_driver_input_and_broadcast(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
index: int,
last_tokens: Optional[torch.Tensor]=None,
previous_hidden_states: Optional[torch.Tensor]=None,
sampling_metadata: Optional[SamplingMetadata]=None
) -> Dict[str, torch.Tensor]:
if sampling_metadata is None and execute_model_req is not None:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
# b x 1
last_tokens = input_tokens.unsqueeze(1)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
tensor_dict = {
"input_tokens": last_tokens,
"previous_hidden_states": previous_hidden_states,
"sample_len": sample_len,
"head_index": index
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict, sampling_metadata
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
self, self,
...@@ -33,25 +86,43 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): ...@@ -33,25 +86,43 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
""" """
self._raise_if_unsupported(execute_model_req) self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list model_outputs = []
last_tokens = None
(input_tokens, seq_lens, previous_hidden_states = None
query_lens) = self._prepare_input_tensors(seq_group_metadata_list) sampling_metadata = None
generators = self.model_runner.get_generators( for index in range(sample_len):
execute_model_req.finished_requests_ids) if self.is_driver_worker:
sampling_metadata = SamplingMetadata.prepare( tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req,
seq_group_metadata_list, seq_lens, query_lens, self.device, sample_len,
self.model_runner.pin_memory, generators) index,
last_tokens,
model_outputs = self.model_runner.model.generate_proposals( previous_hidden_states,
input_ids=input_tokens, sampling_metadata)
previous_hidden_states=execute_model_req.previous_hidden_states. assert sampling_metadata is not None
hidden_states,
num_predict_tokens=sample_len, output, previous_hidden_states = self.model_runner.model.generate_proposals(
sampling_metadata=sampling_metadata) input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
assert len(model_outputs) == sample_len num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=sampling_metadata,
head_index=index)
last_tokens = output.sampled_token_ids
model_outputs.append(output)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of mlp_speculator worker!!!")
self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=None,
head_index=tensor_dict["head_index"])
if self.is_driver_worker:
assert len(model_outputs) == sample_len
return model_outputs, True return model_outputs, True
......
...@@ -350,6 +350,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -350,6 +350,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""MultiStepWorker does not yet implement support for cache swap """MultiStepWorker does not yet implement support for cache swap
operations or beam search. operations or beam search.
""" """
if execute_model_req is None:
return None
if any([ if any([
execute_model_req.blocks_to_swap_in, execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out, execute_model_req.blocks_to_swap_out,
......
...@@ -38,6 +38,7 @@ from vllm.worker.worker import Worker ...@@ -38,6 +38,7 @@ from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -134,6 +135,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -134,6 +135,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min = ( ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min")) draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
if ngram_prompt_lookup_max > 0: if ngram_prompt_lookup_max > 0:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
assert draft_parallel_config.tensor_parallel_size == 1
proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max) ngram_prompt_lookup_max)
...@@ -608,7 +612,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -608,7 +612,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model() self.scorer_worker.execute_model()
if not data["disable_all_speculation"]: if not data["disable_all_speculation"]:
if not self.tree_style_spec_decoding: # if not self.tree_style_spec_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
# # We run the proposer once per lookahead slot. In the future we
# # should delegate how many times it runs to the proposer.
# for _ in range(max(num_lookahead_slots, 1)):
# self.proposer_worker.execute_model()
# else:
# if not data["no_spec"]:
# self.proposer_worker.sampler_output(None, None, None)
if issubclass(type(self.proposer_worker), NonLLMProposerWorkerBase):
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, num_lookahead_slots, None)
else:
# Even if num_lookahead_slots is zero, we want to run the # Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV. # proposer model as it may have KV.
# #
...@@ -616,9 +635,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -616,9 +635,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# should delegate how many times it runs to the proposer. # should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)): for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model() self.proposer_worker.execute_model()
else:
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, None, None)
if not data["no_spec"]: if not data["no_spec"]:
self.scorer_worker.execute_model() self.scorer_worker.execute_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment