Commit 28375803 authored by 王敏's avatar 王敏
Browse files

1.优化medusa推理,节省cpu耗时

2.更新medusa readme
3.解决benchmark_moe报错问题
parent 3c9817d2
......@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
else:
batch_sizes = [args.batch_size]
ray.init()
ray.init(address=None,
ignore_reinit_error=True,
num_gpus=args.tp_size)
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
......
......@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
# medusa 模型需要转换为vllm中Medusa的模型格式
```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/medusa/qwen2_72b_head_4/adapter_model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/sugon/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0,), (0, 0), (0, 0, 0), (0, 0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 1), (0, 1, 0), (0, 2), (1, 0, 0), (2,), (2, 0), (0, 3), (0, 0, 2), (0, 2, 0), (0, 4), (0, 0, 1, 0), (0, 1, 0, 0), (2, 0, 0), (3,), (0, 5), (0, 0, 0, 1), (3, 0), (0, 0, 3), (1, 0, 0, 0), (0, 3, 0), (0, 6), (0, 0, 4), (0, 4, 0), (1, 1), (4,)]"
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
```
此处qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换
此处model.bin是训练后保存的medusa head权重
### Run
......@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.7 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--enforce-eager --dtype float16 --trust-remote-code --port 8086\
--enable-lora --lora-modules medusa-lora=/work/qwen2_72b_head_4 \
--max-lora-rank 32 --lora-extra-vocab-size 0 --merge-lora True \
--lora-target-modules qkv_proj \
--dtype float16 --trust-remote-code --port 8086\
--tree-style-spec-decoding True\
--num-speculative-heads 4 --num-speculative-tokens 33
--num-speculative-heads 4 --num-speculative-tokens 24
```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 2
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 1
# do request
```bash
curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "medusa-lora",
"model": "qwen_medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```bash
```
### benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
可设置max-num-seqs对不同的batch进行性能测试
This diff is collapsed.
......@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.bias'
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
......
......@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property
def prefill_metadata(
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
......@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
return self._cached_prefill_metadata
......@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
return self._cached_decode_metadata
......
......@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
......@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
return self._cached_prefill_metadata
......@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
return self._cached_decode_metadata
......
......@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor
tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables
)
......
......@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_block_tables: Optional[torch.Tensor] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
def __post_init__(self):
# Set during the execution of the first attention op.
......@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor)
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_prefill_metadata
@property
......@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor)
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_decode_metadata
......
......@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus, VLLM_INVALID_TOKEN_ID)
SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
......@@ -989,7 +989,7 @@ class LLMEngine:
output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
if outputs and isinstance(output[0], SamplerOutput):
if outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
......
......@@ -235,7 +235,6 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
original_logits = logits.clone()
_, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking.
......@@ -320,7 +319,7 @@ class Sampler(nn.Module):
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=original_logits)
logits=logits)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
......
......@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k = draft_token_ids.shape[-1]
output_token_id_list = []
logger.info("accept_length:%s", accept_length)
accept_length_list = accept_length.cpu().tolist()
logger.info("accept_length:%s", accept_length_list)
for i in range(batch_size):
output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length[i])
accept_lengths.append(accept_length_list[i])
if not first_step_flags[i]:
select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1]
......
......@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group.
num_speculative_tokens: Optional[int] = None
tree_attn_masks : Optional[torch.Tensor] = None
tree_position_ids : Optional[torch.Tensor] = None
def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt:
......@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
assert self.state.current_step < self.state.num_steps
self.state.current_step += 1
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class SequenceOutput(
msgspec.Struct,
......
......@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k: (v.clone().to(device) if k != "tree_position_ids" else v.clone())
k: v.clone().to(device)
if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
......
......@@ -318,8 +318,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
# tree_style decoding modify probs in _verify_tokens
if not self.tree_style_spec_decoding:
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace()
......@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[int]]:
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
......@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, including bonus tokens.
proposal_verifier_probs = proposal_scores.probs[spec_indices]
if non_spec_indices:
proposal_verifier_probs = proposal_scores.probs[spec_indices]
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_style_spec_decoding:
retrieve_indices = proposals.retrieve_indices
......@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
bonus_token_ids = proposal_scores.token_ids[:, -1:]
if non_spec_indices:
bonus_token_ids = bonus_token_ids[spec_indices, :]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices] \
if proposals.proposal_probs is not None else None
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
proposal_token_ids = proposals.proposal_token_ids
if non_spec_indices:
proposal_token_ids = proposal_token_ids[spec_indices]
# Get tree buffers.
cart_candidates = proposals.cart_candidates[spec_indices] if proposals.cart_candidates is not None else None
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {}
......@@ -820,6 +833,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
previous_logits_list = []
previous_hidden_state_list = []
retrieve_indices = retrieve_indices.cpu()
for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
......@@ -865,13 +880,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
model_input = self.scorer._scorer_worker.model_input
block_tables = None
if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables'):
block_tables = model_input.attn_metadata.block_tables
if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables_list'):
block_tables = model_input.attn_metadata.block_tables_list
if block_tables is None:
raise RuntimeError("Can not get block_tables from model_input.")
block_tables = block_tables.cpu().tolist()
cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine]
block_size = cache_engine.block_size
......@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if accept_legth > 0:
select_indices = select_indices_list[i][1:] + seq_lens[i]
select_indices = select_indices.tolist()
self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride,
select_indices, block_size, block_tables)
target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i]
target_indices = target_indices.tolist()
self.compute_slot_mapping(target_slot_mapping, i*block_table_stride,
target_indices, block_size, block_tables)
......@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype=torch.long,
device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self.kvcache_slot_to_be_moved = src_dst_tensor
def compute_slot_mapping(self, slot_mapping: List[int],
......
......@@ -469,6 +469,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
if hasattr(self.runner, "tree_attn_masks"):
self.tree_attn_masks = self.runner.tree_attn_masks
self.tree_position_ids = self.runner.tree_position_ids
else:
self.tree_attn_masks = None
self.tree_position_ids = None
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
......@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is not None:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
......@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len = 0
if self.runner.model_config.is_encoder_decoder_model:
if self.is_encoder_decoder_model:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
inter_data = self.init_cached_inter_data(
......@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
if self.runner.model_config.is_encoder_decoder_model:
if self.is_encoder_decoder_model:
max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len)
......@@ -847,22 +852,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths.
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks
max_context_len = 0
for inter_data in self.inter_data_list:
max_context_len = max(max_context_len, max(inter_data.context_lens))
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
if inter_data.tree_attn_masks:
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if tree_attention_masks_list:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = self.tree_attn_masks
if tree_attention_masks_tensor is not None:
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
input_positions_tensor = self.tree_position_ids.contiguous()
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
......@@ -1038,6 +1033,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
self.tree_attn_masks: Optional[torch.Tensor] = None
self.tree_position_ids : Optional[torch.Tensor] = None
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
......@@ -1505,6 +1503,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
......
......@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
# set tree_attn_masks and position ids to seq_group_metadata_list
if execute_model_req.tree_attn_masks is not None:
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list):
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
if hasattr(self.model_runner, "set_tree_style_args"):
self.model_runner.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks,
tree_position_ids=execute_model_req.tree_position_ids)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment