Commit f6aa3d19 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-wm-1205' into 'v0.11.0-dev'

去掉无效代码

See merge request dcutoolkit/deeplearing/vllm!285
parents b8412df6 7343379a
...@@ -186,17 +186,13 @@ for chunk_idx in range(cdiv(C, MCC)): ...@@ -186,17 +186,13 @@ for chunk_idx in range(cdiv(C, MCC)):
return curr_o @ W_O return curr_o @ W_O
""" """
import os
import functools import functools
from abc import abstractmethod from abc import abstractmethod
import numpy as np
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar, Union from typing import Generic, Optional, TypeVar, Union
import torch import torch
import os
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs import vllm.envs as envs
...@@ -558,14 +554,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -558,14 +554,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device, device=device,
) )
self.block_table = block_table
self.use_spec_decode = False
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc qo_indptr = prefill.query_start_loc
...@@ -659,31 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -659,31 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA. Currently, only decode is supported for full cudagraphs with MLA.
""" """
m = common_attn_metadata m = common_attn_metadata
# assert m.num_reqs <= (m.num_actual_tokens * assert m.num_reqs <= (m.num_actual_tokens *
# self.reorder_batch_threshold), \ self.reorder_batch_threshold), \
# "MLA only supports decode-only full CUDAGraph capture. " \ "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq." "Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.max_query_len <= self.reorder_batch_threshold # decode only assert m.max_query_len <= self.reorder_batch_threshold # decode only
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m) return self.build(0, m)
def build(self, def build(self,
...@@ -699,15 +669,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -699,15 +669,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because # function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels. # it blocks on all previous kernels.
device = self.device device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
...@@ -873,65 +836,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -873,65 +836,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
- prefill_query_start_loc[:-1] - prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
decode_metadata = None decode_metadata = None
# if num_decodes > 0: if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding: decode_metadata = self._build_decode(
# query_lens = self.num_scheduled_tokens_np[:num_decodes] block_table_tensor=block_table_tensor[:num_decodes, ...],
# cu_num_blocks = np.cumsum(query_lens) seq_lens_cpu=seq_lens_cpu[:num_decodes],
# virtual_batches = cu_num_blocks[-1] seq_lens_device=seq_lens[:num_decodes],
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens) query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets query_start_loc_device=query_start_loc[:num_decodes + 1],
# rarange = np.repeat(query_lens, query_lens) - arange - 1 num_decode_tokens=num_decode_tokens,
)
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
attn_metadata = self.metadata_cls( attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
......
...@@ -131,15 +131,15 @@ class EagleProposer: ...@@ -131,15 +131,15 @@ class EagleProposer:
with_numpy=True) with_numpy=True)
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: Optional[tuple] = None # self.allowed_attn_types: Optional[tuple] = None
if current_platform.is_rocm(): # if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend # # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): # if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
from vllm.v1.attention.backends.rocm_aiter_fa import ( # from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata) # AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata) # rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types) # self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree. # Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree spec_token_tree = self.speculative_config.speculative_token_tree
...@@ -273,8 +273,7 @@ class EagleProposer: ...@@ -273,8 +273,7 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens):
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model( ret_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment