Commit 26084d72 authored by 王敏
Browse files

[fix]解决部分mtp启动报错

parent b924a846
......@@ -186,17 +186,13 @@ for chunk_idx in range(cdiv(C, MCC)):
return curr_o @ W_O
"""
import os
import functools
from abc import abstractmethod
import numpy as np
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar, Union
import torch
import os
from tqdm import tqdm
import vllm.envs as envs
......@@ -558,14 +554,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device,
)
self.block_table = block_table
self.use_spec_decode = False
# support for cudagraph spec decoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
......@@ -659,31 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
# assert m.num_reqs <= (m.num_actual_tokens *
# self.reorder_batch_threshold), \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.max_query_len <= self.reorder_batch_threshold # decode only
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec decoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len <= self.reorder_batch_threshold # decode only
return self.build(0, m)
def build(self,
......@@ -699,15 +669,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......@@ -873,65 +836,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
- prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
decode_metadata = None
# if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
# cu_num_blocks = np.cumsum(query_lens)
# virtual_batches = cu_num_blocks[-1]
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# rarange = np.repeat(query_lens, query_lens) - arange - 1
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
......
......@@ -131,15 +131,15 @@ class EagleProposer:
with_numpy=True)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: Optional[tuple] = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types)
# self.allowed_attn_types: Optional[tuple] = None
# if current_platform.is_rocm():
# rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
# if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
# from vllm.v1.attention.backends.rocm_aiter_fa import (
# AiterFlashAttentionMetadata)
# rocm_types.append(AiterFlashAttentionMetadata)
# self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment