Commit 8c552f40 authored by 王敏
Browse files

[fix] 解决开启 MTP 后，在极端情况下遇到显存不足时，MLA 中申请的 tensor 数据错乱的问题

parent 2c8a16d6
...@@ -212,7 +212,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -212,7 +212,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down, is_pin_memory_available
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -399,18 +399,41 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -399,18 +399,41 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.block_table = block_table self.block_table = block_table
self.use_spec_decode = False self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32) self.decode_token_num_threshold = 1
# support for cudagraph spec decoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
self.decode_token_num_threshold = 1
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
if speculative_config and speculative_config.num_speculative_tokens > 1: if speculative_config and speculative_config.num_speculative_tokens > 1:
self.use_spec_decode = True self.use_spec_decode = True
self.decode_token_num_threshold = 1 + speculative_config.num_speculative_tokens self.decode_token_num_threshold = 1 + speculative_config.num_speculative_tokens
self.device = self.runner.device
self.pin_memory = is_pin_memory_available()
#self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
self.num_scheduled_tokens = torch.zeros(scheduler_config.max_num_seqs,
dtype=torch.int32,
device=self.device)
self.num_scheduled_tokens_cpu = torch.zeros(scheduler_config.max_num_seqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.num_scheduled_tokens_np = self.num_scheduled_tokens_cpu.numpy()
self.seq_lens_minus = torch.zeros(scheduler_config.max_num_seqs * self.decode_token_num_threshold,
dtype=torch.int32,
device=self.device)
self.seq_lens_minus_cpu = torch.zeros(scheduler_config.max_num_seqs * self.decode_token_num_threshold,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_minus_np = self.seq_lens_minus_cpu.numpy()
# support for cudagraph spec decoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
...@@ -444,6 +467,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -444,6 +467,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx] num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx]
num_prompt_tokens = input_batch.num_prompt_tokens[req_idx] num_prompt_tokens = input_batch.num_prompt_tokens[req_idx]
self.num_scheduled_tokens_np[i] = num_tokens self.num_scheduled_tokens_np[i] = num_tokens
if num_computed_tokens < num_prompt_tokens or (num_tokens > self.decode_token_num_threshold): if num_computed_tokens < num_prompt_tokens or (num_tokens > self.decode_token_num_threshold):
prefills.append(i) prefills.append(i)
num_prefill_tokens += num_tokens num_prefill_tokens += num_tokens
...@@ -646,14 +670,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -646,14 +670,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if self._num_decodes > 0: if self._num_decodes > 0:
if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding: if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
query_lens = self.num_scheduled_tokens_np[:self._num_decodes] query_lens = self.num_scheduled_tokens_np[:self._num_decodes]
self.num_scheduled_tokens[:self._num_decodes].copy_(
self.num_scheduled_tokens_cpu[:self._num_decodes],
non_blocking=True)
repeats = self.num_scheduled_tokens[:self._num_decodes]
cu_num_blocks = np.cumsum(query_lens) cu_num_blocks = np.cumsum(query_lens)
virtual_batches = cu_num_blocks[-1] virtual_batches = cu_num_blocks[-1]
block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens) block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
rarange = np.repeat(query_lens, query_lens) - arange - 1 rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to( self.seq_lens_minus_np[:rarange.size] = rarange
block_table_tensor.device, non_blocking=True).contiguous() self.seq_lens_minus[:rarange.size].copy_(
self.seq_lens_minus_cpu[:rarange.size],
non_blocking=True)
seq_lens_minus = self.seq_lens_minus[:rarange.size]
if envs.VLLM_ZERO_OVERHEAD: if envs.VLLM_ZERO_OVERHEAD:
decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]), decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]),
...@@ -670,8 +702,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -670,8 +702,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
block_table_tensor[:self._num_decodes, ...], block_table_tensor[:self._num_decodes, ...],
repeats, dim=0).contiguous() repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous() decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None: if self.spec_decode_block_table_tensor is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment