Commit 201768d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

优化medusa 推理

See merge request dcutoolkit/deeplearing/vllm!41
parents 87a2e37f 28375803
...@@ -304,7 +304,9 @@ def main(args: argparse.Namespace): ...@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
else: else:
batch_sizes = [args.batch_size] batch_sizes = [args.batch_size]
ray.init() ray.init(address=None,
ignore_reinit_error=True,
num_gpus=args.tp_size)
num_gpus = int(ray.available_resources()["GPU"]) num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
......
...@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py] ...@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
# medusa 模型需要转换为vllm中Medusa的模型格式 # medusa 模型需要转换为vllm中Medusa的模型格式
```bash ```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/medusa/qwen2_72b_head_4/adapter_model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/sugon/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0,), (0, 0), (0, 0, 0), (0, 0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 1), (0, 1, 0), (0, 2), (1, 0, 0), (2,), (2, 0), (0, 3), (0, 0, 2), (0, 2, 0), (0, 4), (0, 0, 1, 0), (0, 1, 0, 0), (2, 0, 0), (3,), (0, 5), (0, 0, 0, 1), (3, 0), (0, 0, 3), (1, 0, 0, 0), (0, 3, 0), (0, 6), (0, 0, 4), (0, 4, 0), (1, 1), (4,)]" python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
``` ```
此处qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换 此处model.bin是训练后保存的medusa head权重
### Run ### Run
...@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m ...@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
python3 -m vllm.entrypoints.openai.api_server \ python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \ --served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \ --model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.7 \ --max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \ --speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \ --speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \ --speculative-disable-by-batch-size 4 \
--use-v2-block-manager \ --use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \ --spec-decoding-acceptance-method typical_acceptance_sampler \
--enforce-eager --dtype float16 --trust-remote-code --port 8086\ --dtype float16 --trust-remote-code --port 8086\
--enable-lora --lora-modules medusa-lora=/work/qwen2_72b_head_4 \
--max-lora-rank 32 --lora-extra-vocab-size 0 --merge-lora True \
--lora-target-modules qkv_proj \
--tree-style-spec-decoding True\ --tree-style-spec-decoding True\
--num-speculative-heads 4 --num-speculative-tokens 33 --num-speculative-heads 4 --num-speculative-tokens 24
``` ```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数 merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 2 num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 1
# do request # do request
```bash ```bash
curl http://localhost:8086/v1/completions \ curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "medusa-lora", "model": "qwen_medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n", "prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256, "max_tokens": 256,
"temperature": 0.0 "temperature": 0.0
}' }'
```bash ```
### benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
可设置max-num-seqs对不同的batch进行性能测试
This diff is collapsed.
...@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight' TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight' TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.bias' TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight' VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias' VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
......
...@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property @property
def prefill_metadata( def prefill_metadata(
self) -> Optional["BlocksparseFlashAttentionMetadata"]: self) -> Optional["BlocksparseFlashAttentionMetadata"]:
...@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_prefill_metadata return self._cached_prefill_metadata
...@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property @property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
...@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_prefill_metadata return self._cached_prefill_metadata
...@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables
) )
......
...@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_block_tables: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
...@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor) tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
...@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor) tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus, VLLM_INVALID_TOKEN_ID) SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
...@@ -989,7 +989,7 @@ class LLMEngine: ...@@ -989,7 +989,7 @@ class LLMEngine:
output = [outputs_by_sequence_group[0][i]] output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step # tree style speculative decoding may generate empty output in first step
if outputs and isinstance(output[0], SamplerOutput): if outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output] samples = [o.samples[0] for o in output]
valid_samples = [ valid_samples = [
sample for sample in samples sample for sample in samples
......
...@@ -235,7 +235,6 @@ class Sampler(nn.Module): ...@@ -235,7 +235,6 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
""" """
assert logits is not None assert logits is not None
original_logits = logits.clone()
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
...@@ -320,7 +319,7 @@ class Sampler(nn.Module): ...@@ -320,7 +319,7 @@ class Sampler(nn.Module):
sample_logprobs, sample_logprobs,
on_device_tensors=on_device_tensors, on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output, skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=original_logits) logits=logits)
@property @property
def _should_modify_greedy_probs_inplace(self) -> bool: def _should_modify_greedy_probs_inplace(self) -> bool:
......
...@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k = draft_token_ids.shape[-1] k = draft_token_ids.shape[-1]
output_token_id_list = [] output_token_id_list = []
logger.info("accept_length:%s", accept_length) accept_length_list = accept_length.cpu().tolist()
logger.info("accept_length:%s", accept_length_list)
for i in range(batch_size): for i in range(batch_size):
output_best_candidates.append(best_candidate[i]) output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length[i]) accept_lengths.append(accept_length_list[i])
if not first_step_flags[i]: if not first_step_flags[i]:
select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1] select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1]
......
...@@ -996,9 +996,6 @@ class SequenceGroupMetadata( ...@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group. # TODO: We should maintain this states out of the sequence group.
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
tree_attn_masks : Optional[torch.Tensor] = None
tree_position_ids : Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None: if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt: if self.is_prompt:
...@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata( ...@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
assert self.state.current_step < self.state.num_steps assert self.state.current_step < self.state.num_steps
self.state.current_step += 1 self.state.current_step += 1
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class SequenceOutput( class SequenceOutput(
msgspec.Struct, msgspec.Struct,
......
...@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device # Move the tensors in the dictionary to the specified device
medusa_buffers = { medusa_buffers = {
k: (v.clone().to(device) if k != "tree_position_ids" else v.clone()) k: v.clone().to(device)
if isinstance(v, torch.Tensor) if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device) else torch.tensor(v, device=device)
for k, v in medusa_buffers.items() for k, v in medusa_buffers.items()
......
...@@ -318,6 +318,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -318,6 +318,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
""" """
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True ) = True
# tree_style decoding modify probs in _verify_tokens
if not self.tree_style_spec_decoding:
(self.scorer_worker.model_runner.model.sampler. (self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_include_gpu_probs_tensor()
...@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores, proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
max_proposal_len: int, max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[int]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
"""Determine which speculative tokens are accepted using the """Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models. probabilities of each token according to the proposer and scorer models.
...@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
original_indices = spec_indices + non_spec_indices original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, including bonus tokens. # Get probabilities of target model, including bonus tokens.
if non_spec_indices:
proposal_verifier_probs = proposal_scores.probs[spec_indices] proposal_verifier_probs = proposal_scores.probs[spec_indices]
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_style_spec_decoding: if self.tree_style_spec_decoding:
retrieve_indices = proposals.retrieve_indices retrieve_indices = proposals.retrieve_indices
...@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model. # Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] bonus_token_ids = proposal_scores.token_ids[:, -1:]
if non_spec_indices:
bonus_token_ids = bonus_token_ids[spec_indices, :]
# Get probabilities according to proposal method. # Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices] \ proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if proposals.proposal_probs is not None else None if non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens. # Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices] proposal_token_ids = proposals.proposal_token_ids
if non_spec_indices:
proposal_token_ids = proposal_token_ids[spec_indices]
# Get tree buffers. # Get tree buffers.
cart_candidates = proposals.cart_candidates[spec_indices] if proposals.cart_candidates is not None else None cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments # Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {} sampler_extra_kwargs: Dict[str, Any] = {}
...@@ -821,6 +834,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -821,6 +834,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
previous_hidden_state_list = [] previous_hidden_state_list = []
retrieve_indices = retrieve_indices.cpu()
for i in range(batch_size): for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0) logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
previous_logits_list.append(logit) previous_logits_list.append(logit)
...@@ -865,14 +880,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -865,14 +880,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
model_input = self.scorer._scorer_worker.model_input model_input = self.scorer._scorer_worker.model_input
block_tables = None block_tables = None
if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables'): if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables_list'):
block_tables = model_input.attn_metadata.block_tables block_tables = model_input.attn_metadata.block_tables_list
if block_tables is None: if block_tables is None:
raise RuntimeError("Can not get block_tables from model_input.") raise RuntimeError("Can not get block_tables from model_input.")
block_tables = block_tables.cpu().tolist()
cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine] cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine]
block_size = cache_engine.block_size block_size = cache_engine.block_size
batch_size = len(select_indices_list) batch_size = len(select_indices_list)
...@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if accept_legth > 0: if accept_legth > 0:
select_indices = select_indices_list[i][1:] + seq_lens[i] select_indices = select_indices_list[i][1:] + seq_lens[i]
select_indices = select_indices.tolist()
self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride, self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride,
select_indices, block_size, block_tables) select_indices, block_size, block_tables)
target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i] target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i]
target_indices = target_indices.tolist()
self.compute_slot_mapping(target_slot_mapping, i*block_table_stride, self.compute_slot_mapping(target_slot_mapping, i*block_table_stride,
target_indices, block_size, block_tables) target_indices, block_size, block_tables)
...@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype=torch.long, dtype=torch.long,
device=self.device).view(-1, 1) device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2] src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self.kvcache_slot_to_be_moved = src_dst_tensor self.kvcache_slot_to_be_moved = src_dst_tensor
def compute_slot_mapping(self, slot_mapping: List[int], def compute_slot_mapping(self, slot_mapping: List[int],
......
...@@ -470,6 +470,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -470,6 +470,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.block_aligned_sliding_window = \ self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
if hasattr(self.runner, "tree_attn_masks"):
self.tree_attn_masks = self.runner.tree_attn_masks
self.tree_position_ids = self.runner.tree_position_ids
else:
self.tree_attn_masks = None
self.tree_position_ids = None
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens """Compute context length, sequence length and tokens
...@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is not None:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
...@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len = 0 encoder_seq_len = 0
if self.runner.model_config.is_encoder_decoder_model: if self.is_encoder_decoder_model:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
inter_data = self.init_cached_inter_data( inter_data = self.init_cached_inter_data(
...@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not inter_data.is_prompt: if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len, max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens)) max(inter_data.seq_lens))
if self.runner.model_config.is_encoder_decoder_model: if self.is_encoder_decoder_model:
max_encoder_seq_len = max(max_encoder_seq_len, max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len) inter_data.encoder_seq_len)
...@@ -849,20 +854,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -849,20 +854,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks # prepare tree attention masks
max_context_len = 0 tree_attention_masks_tensor = self.tree_attn_masks
for inter_data in self.inter_data_list: if tree_attention_masks_tensor is not None:
max_context_len = max(max_context_len, max(inter_data.context_lens))
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
if inter_data.tree_attn_masks:
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if tree_attention_masks_list:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous() tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
input_positions_tensor = self.tree_position_ids.contiguous()
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
...@@ -1039,6 +1034,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1039,6 +1034,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sampling_metadata_cache: SamplingMetadataCache = \ self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() SamplingMetadataCache()
self.tree_attn_masks: Optional[torch.Tensor] = None
self.tree_position_ids : Optional[torch.Tensor] = None
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
...@@ -1506,6 +1504,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1506,6 +1504,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
""" """
......
...@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
# set tree_attn_masks and position ids to seq_group_metadata_list if hasattr(self.model_runner, "set_tree_style_args"):
if execute_model_req.tree_attn_masks is not None: self.model_runner.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks,
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): tree_position_ids=execute_model_req.tree_position_ids)
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment