Commit 52121d00 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton mla

parent 40083064
......@@ -488,10 +488,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1.' + sha[:7]
version = 'das.opt1.' + sha[:7]
else:
if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1'
version = 'das.opt1'
# dtk version
......
# SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -16,6 +19,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
......@@ -34,6 +39,54 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata)
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
class TritonMLABackend(AttentionBackend):
@staticmethod
......@@ -736,11 +789,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale,
attn_metadata.num_kv_splits, self.scale, # config,
PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment