Commit f7be09fc authored by zhuwenwen
Browse files

Merge branch 'v0.9.1-dev-wm' into 'v0.9.1-dev'

1.优化v1 engine mtp;2.添加v1 engine chunked prefill开关

See merge request dcutoolkit/deeplearing/vllm!152
parents b2fa85ce f0acced0
...@@ -423,6 +423,10 @@ class ModelConfig: ...@@ -423,6 +423,10 @@ class ModelConfig:
- "vllm" will use the vLLM model implementation.\n - "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation.""" - "transformers" will use the Transformers model implementation."""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -452,6 +456,7 @@ class ModelConfig: ...@@ -452,6 +456,7 @@ class ModelConfig:
factors.append(self.rope_theta) factors.append(self.rope_theta)
# hf_config can control how the model looks! # hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string()) factors.append(self.hf_config.to_json_string())
factors.append(self.enable_chunked_prefill)
str_factors = str(factors) str_factors = str(factors)
assert_hashable(str_factors) assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
......
...@@ -956,6 +956,7 @@ class EngineArgs: ...@@ -956,6 +956,7 @@ class EngineArgs:
override_generation_config=self.override_generation_config, override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode, enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl, model_impl=self.model_impl,
enable_chunked_prefill=self.enable_chunked_prefill
) )
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
...@@ -1046,7 +1047,7 @@ class EngineArgs: ...@@ -1046,7 +1047,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine. # Set default arguments for V0 or V1 Engine.
if use_v1: if use_v1:
self._set_default_args_v1(usage_context) self._set_default_args_v1(usage_context, model_config)
else: else:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
...@@ -1532,12 +1533,16 @@ class EngineArgs: ...@@ -1532,12 +1533,16 @@ class EngineArgs:
if self.max_num_seqs is None: if self.max_num_seqs is None:
self.max_num_seqs = 256 self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None: def _set_default_args_v1(self, usage_context: UsageContext, model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine.""" """Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills. # V1 always uses chunked prefills.
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
# V1 enables prefix caching by default. # V1 enables prefix caching by default.
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = True self.enable_prefix_caching = True
......
...@@ -100,6 +100,15 @@ def with_amdsmi_context(fn): ...@@ -100,6 +100,15 @@ def with_amdsmi_context(fn):
return wrapper return wrapper
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into the physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, it is read as a comma-separated
    list and ``device_id`` is used as an index into that list; the selected
    entry is returned as an ``int``. When the variable is not set,
    ``device_id`` is returned unchanged.

    Args:
        device_id: Index into the list of visible devices.

    Returns:
        The entry of ``CUDA_VISIBLE_DEVICES`` at position ``device_id``
        converted to ``int``, or ``device_id`` itself if the variable is
        unset.

    Raises:
        RuntimeError: If ``CUDA_VISIBLE_DEVICES`` is set to an empty string,
            so there is no entry to map to.
        IndexError: If ``device_id`` is outside the visible-device list.
        ValueError: If the selected entry is not an integer.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    device_ids = visible_devices.split(",")
    # "".split(",") yields [""]; int("") would otherwise fail with an
    # unhelpful "invalid literal" ValueError, so report the cause directly.
    if device_ids == [""]:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no device "
            "is visible; unset it or list at least one device id.")
    return int(device_ids[device_id])
@cache @cache
def on_gfx1x() -> bool: def on_gfx1x() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
...@@ -298,13 +307,10 @@ class RocmPlatform(Platform): ...@@ -298,13 +307,10 @@ class RocmPlatform(Platform):
@with_amdsmi_context @with_amdsmi_context
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = cls.device_id_to_physical_device_id(device_id) physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id] handle = amdsmi_get_processor_handles()[physical_device_id]
asic_info = amdsmi_get_gpu_asic_info(handle) # return amdsmi_get_gpu_asic_info(handle)["market_name"]
device_name: str = asic_info["device_id"] return torch.cuda.get_device_name(device_id)
if device_name in _ROCM_DEVICE_ID_NAME_MAP:
return _ROCM_DEVICE_ID_NAME_MAP[device_name]
return asic_info["market_name"]
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
......
...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention( ...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)): for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C) chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
...@@ -191,6 +191,9 @@ import functools ...@@ -191,6 +191,9 @@ import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from itertools import chain
import numpy as np
import torch import torch
import os import os
...@@ -208,7 +211,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -208,7 +211,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down, is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
...@@ -385,6 +388,32 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -385,6 +388,32 @@ class MLACommonMetadataBuilder(Generic[M]):
) )
self.block_table = block_table self.block_table = block_table
self._use_spec_decode = False
self.pin_memory = is_pin_memory_available()
self._num_scheduled_tokens = torch.zeros(scheduler_config.max_num_seqs,
dtype=torch.int32,
device=runner.device)
self._num_scheduled_tokens_cpu_tensor = torch.zeros(
(scheduler_config.max_num_seqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=self.pin_memory,
)
self._num_scheduled_tokens_np = self._num_scheduled_tokens_cpu_tensor.numpy()
self._seq_lens_minus = torch.zeros(scheduler_config.max_num_seqs*5,
dtype=torch.int32,
device=runner.device)
self._seq_lens_minus_cpu_tensor = torch.zeros(
(scheduler_config.max_num_seqs*5, ),
device="cpu",
dtype=torch.int32,
pin_memory=self.pin_memory,
)
self._seq_lens_minus_np = self._seq_lens_minus_cpu_tensor.numpy()
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are and # We now want to reorder the batch so that the "decode" requests are and
...@@ -397,6 +426,8 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -397,6 +426,8 @@ class MLACommonMetadataBuilder(Generic[M]):
prefills = [] prefills = []
num_decode_tokens = 0 num_decode_tokens = 0
num_prefill_tokens = 0 num_prefill_tokens = 0
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens = scheduler_output.num_scheduled_tokens[req_id]
...@@ -404,12 +435,23 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -404,12 +435,23 @@ class MLACommonMetadataBuilder(Generic[M]):
# we should update this to something like < 8 in the future but # we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports # currently the TritonMLA._forward_decode only supports
# num_tokens = 1 # num_tokens = 1
if num_tokens == 1:
decodes.append(i) # if num_tokens == 2 or num_tokens == 1:
num_decode_tokens += num_tokens # decodes.append(i)
else: # num_decode_tokens += num_tokens
# else:
# prefills.append(i)
# num_prefill_tokens += num_tokens
req_idx = input_batch.req_id_to_index[req_id]
num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx]
num_prompt_tokens = input_batch.num_prompt_tokens[req_idx]
self._num_scheduled_tokens_np[i] = num_tokens
if num_computed_tokens < num_prompt_tokens:
prefills.append(i) prefills.append(i)
num_prefill_tokens += num_tokens num_prefill_tokens += num_tokens
else:
decodes.append(i)
num_decode_tokens += num_tokens
# We hope that this is fairly minimal since decodes # We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are # should be around for a number of iterations so hopefully they are
...@@ -435,6 +477,11 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -435,6 +477,11 @@ class MLACommonMetadataBuilder(Generic[M]):
input_batch.swap_states(prefills[i - 1], decode_idx) input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True modified_batch = True
# num_scheduled_tokens also need to be swapped
tmp = self._num_scheduled_tokens_np[decode_idx]
self._num_scheduled_tokens_np[decode_idx] = self._num_scheduled_tokens_np[prefills[i - 1]]
self._num_scheduled_tokens_np[prefills[i - 1]] = tmp
# Save for next `build` call # Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a # TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this # better way of doing this
...@@ -442,6 +489,12 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -442,6 +489,12 @@ class MLACommonMetadataBuilder(Generic[M]):
self._num_prefills = num_prefills self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens self._num_prefill_tokens = num_prefill_tokens
self._use_spec_decode = use_spec_decode
if use_spec_decode:
self._num_scheduled_tokens[:len(input_batch.req_ids)].copy_(
self._num_scheduled_tokens_cpu_tensor[:len(input_batch.req_ids)],
non_blocking=True)
return modified_batch return modified_batch
...@@ -548,10 +601,41 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -548,10 +601,41 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata = None decode_metadata = None
if self._num_decodes > 0: if self._num_decodes > 0:
decode_metadata = self._build_decode( if self._use_spec_decode:
block_table_tensor=block_table_tensor[:self._num_decodes, ...], # generate block_table/seq_lens of mla in spec decoding scenarios
seq_lens=seq_lens[:self._num_decodes], if common_attn_metadata.num_rejected_tokens_tuple is None:
) repeats = self._num_scheduled_tokens[:self._num_decodes]
repeats_cpu = self._num_scheduled_tokens_np[:self._num_decodes]
else:
repeats = self._num_scheduled_tokens[:self._num_decodes] - \
common_attn_metadata.num_rejected_tokens_tuple[1][:self._num_decodes]
num_rejected_tokens = common_attn_metadata.num_rejected_tokens_tuple[0][:self._num_decodes]
repeats_cpu = self._num_scheduled_tokens_np[:self._num_decodes] - \
np.array(num_rejected_tokens)
self._num_decode_tokens -= sum(num_rejected_tokens)
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0)
total_decode_tokens = np.sum(repeats_cpu)
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0)
self._seq_lens_minus_np[:total_decode_tokens] = np.fromiter(
chain.from_iterable(np.flip(np.arange(x)) for x in repeats_cpu),
dtype=int)
self._seq_lens_minus[:total_decode_tokens].copy_(self._seq_lens_minus_cpu_tensor[:total_decode_tokens],
non_blocking=True)
decode_seq_lens = decode_seq_lens - self._seq_lens_minus[:total_decode_tokens]
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
return self.metadata_cls( return self.metadata_cls(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
......
...@@ -17,7 +17,8 @@ class CommonAttentionMetadata: ...@@ -17,7 +17,8 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
num_rejected_tokens_tuple: tuple[list[int], torch.Tensor] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
def validate_kv_sharing_target(current_layer_name, target_layer_name, def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context): static_forward_context):
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.models import supports_multimodal ...@@ -14,6 +14,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDecodeMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
...@@ -91,7 +92,9 @@ class EagleProposer: ...@@ -91,7 +92,9 @@ class EagleProposer:
cu_num_tokens: torch.Tensor, cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req] # [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor, block_table: torch.Tensor,
sampling_metadata: SamplingMetadata, # [batch_size]
num_rejected_tokens_tuple: tuple[list[int], torch.Tensor],
sampling_metadata: SamplingMetadata
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -138,7 +141,9 @@ class EagleProposer: ...@@ -138,7 +141,9 @@ class EagleProposer:
max_query_len = query_lens.max().item() max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens) query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_rejected_tokens_tuple=num_rejected_tokens_tuple)
assert self.runner is not None assert self.runner is not None
...@@ -196,7 +201,12 @@ class EagleProposer: ...@@ -196,7 +201,12 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices] positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.method == "deepseek_mtp":
hidden_states = last_hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \ if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]: batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
...@@ -205,6 +215,17 @@ class EagleProposer: ...@@ -205,6 +215,17 @@ class EagleProposer:
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.num_decodes = batch_size
attn_metadata.num_decode_tokens = batch_size
attn_metadata.num_prefills = 0
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=(seq_lens + 1),
)
for _ in range(self.num_speculative_tokens - 1): for _ in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
...@@ -224,23 +245,28 @@ class EagleProposer: ...@@ -224,23 +245,28 @@ class EagleProposer:
clamped_positions = torch.where(exceeds_max_model_len, 0, clamped_positions = torch.where(exceeds_max_model_len, 0,
positions) positions)
# Increment the sequence lengths. if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.max_seq_len += 1 attn_metadata.decode.seq_lens += 1
attn_metadata.seq_lens += 1 else:
# Consider max model length. attn_metadata.seq_lens += 1
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len) # Increment the sequence lengths.
# For the requests that exceed the max model length, we set the attn_metadata.max_seq_len += 1
# sequence length to 1 to minimize their overheads in attention. # Consider max model length.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping. # Compute the slot mapping.
block_numbers = clamped_positions // self.block_size block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1, block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1)) index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1) block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size + attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size) clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length. # Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the # Otherwise, the KV cache will be inadvertently updated with the
# padding tokens. # padding tokens.
...@@ -256,12 +282,18 @@ class EagleProposer: ...@@ -256,12 +282,18 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model( ret_hidden_states = self.model(
self.input_ids[:input_batch_size], self.input_ids[:input_batch_size],
self.positions[:input_batch_size], self.positions[:input_batch_size],
self.hidden_states[:input_batch_size], self.hidden_states[:input_batch_size],
) )
hidden_states = hidden_states[:batch_size] if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states[:batch_size]
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size],
None) None)
......
...@@ -1441,6 +1441,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1441,6 +1441,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
block_table = None block_table = None
num_rejected_tokens_tuple = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens] target_token_ids = self.input_ids[:num_scheduled_tokens]
...@@ -1480,6 +1481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1480,6 +1481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
num_rejected_tokens_tuple = (num_rejected_tokens, num_rejected_tokens_tensor)
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
...@@ -1489,6 +1491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1489,6 +1491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens=cu_num_tokens, cu_num_tokens=cu_num_tokens,
block_table=block_table, block_table=block_table,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
num_rejected_tokens_tuple=num_rejected_tokens_tuple
) )
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment