Commit f7be09fc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.1-dev-wm' into 'v0.9.1-dev'

1.优化v1 engine mtp;2.添加v1 engine chunked prefill开关

See merge request dcutoolkit/deeplearing/vllm!152
parents b2fa85ce f0acced0
......@@ -423,6 +423,10 @@ class ModelConfig:
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -452,6 +456,7 @@ class ModelConfig:
factors.append(self.rope_theta)
# hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string())
factors.append(self.enable_chunked_prefill)
str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
......
......@@ -956,6 +956,7 @@ class EngineArgs:
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
enable_chunked_prefill=self.enable_chunked_prefill
)
def create_load_config(self) -> LoadConfig:
......@@ -1046,7 +1047,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
self._set_default_args_v1(usage_context, model_config)
else:
self._set_default_args_v0(model_config)
......@@ -1532,12 +1533,16 @@ class EngineArgs:
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None:
def _set_default_args_v1(self, usage_context: UsageContext, model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
# V1 enables prefix caching by default.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
......
......@@ -100,6 +100,15 @@ def with_amdsmi_context(fn):
return wrapper
def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
@cache
def on_gfx1x() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
......@@ -298,13 +307,10 @@ class RocmPlatform(Platform):
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = cls.device_id_to_physical_device_id(device_id)
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
asic_info = amdsmi_get_gpu_asic_info(handle)
device_name: str = asic_info["device_id"]
if device_name in _ROCM_DEVICE_ID_NAME_MAP:
return _ROCM_DEVICE_ID_NAME_MAP[device_name]
return asic_info["market_name"]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
......
......@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
Sc = chunk_end - chunk_start_table
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
......@@ -191,6 +191,9 @@ import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from itertools import chain
import numpy as np
import torch
import os
......@@ -208,7 +211,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils import cdiv, round_down, is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
......@@ -385,6 +388,32 @@ class MLACommonMetadataBuilder(Generic[M]):
)
self.block_table = block_table
self._use_spec_decode = False
self.pin_memory = is_pin_memory_available()
self._num_scheduled_tokens = torch.zeros(scheduler_config.max_num_seqs,
dtype=torch.int32,
device=runner.device)
self._num_scheduled_tokens_cpu_tensor = torch.zeros(
(scheduler_config.max_num_seqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=self.pin_memory,
)
self._num_scheduled_tokens_np = self._num_scheduled_tokens_cpu_tensor.numpy()
self._seq_lens_minus = torch.zeros(scheduler_config.max_num_seqs*5,
dtype=torch.int32,
device=runner.device)
self._seq_lens_minus_cpu_tensor = torch.zeros(
(scheduler_config.max_num_seqs*5, ),
device="cpu",
dtype=torch.int32,
pin_memory=self.pin_memory,
)
self._seq_lens_minus_np = self._seq_lens_minus_cpu_tensor.numpy()
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are and
......@@ -397,6 +426,8 @@ class MLACommonMetadataBuilder(Generic[M]):
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
......@@ -404,12 +435,23 @@ class MLACommonMetadataBuilder(Generic[M]):
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
# if num_tokens == 2 or num_tokens == 1:
# decodes.append(i)
# num_decode_tokens += num_tokens
# else:
# prefills.append(i)
# num_prefill_tokens += num_tokens
req_idx = input_batch.req_id_to_index[req_id]
num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx]
num_prompt_tokens = input_batch.num_prompt_tokens[req_idx]
self._num_scheduled_tokens_np[i] = num_tokens
if num_computed_tokens < num_prompt_tokens:
prefills.append(i)
num_prefill_tokens += num_tokens
else:
decodes.append(i)
num_decode_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
......@@ -435,6 +477,11 @@ class MLACommonMetadataBuilder(Generic[M]):
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# num_scheduled_tokens also need to be swapped
tmp = self._num_scheduled_tokens_np[decode_idx]
self._num_scheduled_tokens_np[decode_idx] = self._num_scheduled_tokens_np[prefills[i - 1]]
self._num_scheduled_tokens_np[prefills[i - 1]] = tmp
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
......@@ -442,6 +489,12 @@ class MLACommonMetadataBuilder(Generic[M]):
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
self._use_spec_decode = use_spec_decode
if use_spec_decode:
self._num_scheduled_tokens[:len(input_batch.req_ids)].copy_(
self._num_scheduled_tokens_cpu_tensor[:len(input_batch.req_ids)],
non_blocking=True)
return modified_batch
......@@ -548,10 +601,41 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
if self._use_spec_decode:
# generate block_table/seq_lens of mla in spec decoding scenarios
if common_attn_metadata.num_rejected_tokens_tuple is None:
repeats = self._num_scheduled_tokens[:self._num_decodes]
repeats_cpu = self._num_scheduled_tokens_np[:self._num_decodes]
else:
repeats = self._num_scheduled_tokens[:self._num_decodes] - \
common_attn_metadata.num_rejected_tokens_tuple[1][:self._num_decodes]
num_rejected_tokens = common_attn_metadata.num_rejected_tokens_tuple[0][:self._num_decodes]
repeats_cpu = self._num_scheduled_tokens_np[:self._num_decodes] - \
np.array(num_rejected_tokens)
self._num_decode_tokens -= sum(num_rejected_tokens)
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0)
total_decode_tokens = np.sum(repeats_cpu)
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0)
self._seq_lens_minus_np[:total_decode_tokens] = np.fromiter(
chain.from_iterable(np.flip(np.arange(x)) for x in repeats_cpu),
dtype=int)
self._seq_lens_minus[:total_decode_tokens].copy_(self._seq_lens_minus_cpu_tensor[:total_decode_tokens],
non_blocking=True)
decode_seq_lens = decode_seq_lens - self._seq_lens_minus[:total_decode_tokens]
decode_metadata = self._build_decode(
block_table_tensor=decode_block_table_tensor,
seq_lens=decode_seq_lens,
)
else:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
......
......@@ -17,7 +17,8 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_rejected_tokens_tuple: tuple[list[int], torch.Tensor] = None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
......
......@@ -14,6 +14,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDecodeMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
......@@ -91,7 +92,9 @@ class EagleProposer:
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
# [batch_size]
num_rejected_tokens_tuple: tuple[list[int], torch.Tensor],
sampling_metadata: SamplingMetadata
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -138,7 +141,9 @@ class EagleProposer:
max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_rejected_tokens_tuple=num_rejected_tokens_tuple)
assert self.runner is not None
......@@ -196,7 +201,12 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.method == "deepseek_mtp":
hidden_states = last_hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
......@@ -205,6 +215,17 @@ class EagleProposer:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.num_decodes = batch_size
attn_metadata.num_decode_tokens = batch_size
attn_metadata.num_prefills = 0
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=(seq_lens + 1),
)
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
......@@ -224,23 +245,28 @@ class EagleProposer:
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.decode.seq_lens += 1
else:
attn_metadata.seq_lens += 1
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
......@@ -256,12 +282,18 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
ret_hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
hidden_states = hidden_states[:batch_size]
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states[:batch_size]
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
......
......@@ -1441,6 +1441,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
block_table = None
num_rejected_tokens_tuple = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
......@@ -1480,6 +1481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
num_rejected_tokens_tuple = (num_rejected_tokens, num_rejected_tokens_tensor)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
......@@ -1489,6 +1491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens=cu_num_tokens,
block_table=block_table,
sampling_metadata=sampling_metadata,
num_rejected_tokens_tuple=num_rejected_tokens_tuple
)
spec_token_ids = draft_token_ids.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment