Commit 52121d00 authored by zhuwenwen
Browse files

update triton mla

parent 40083064
......@@ -488,10 +488,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1.' + sha[:7]
version = 'das.opt1.' + sha[:7]
else:
if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1'
version = 'das.opt1'
# dtk version
......
# SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -16,6 +19,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
......@@ -32,6 +37,54 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
    """Build the JSON file name of an MLA decode-attention configuration.

    The name always starts with ``QH=<QH>_KVH=<KVH>_QKD=<QKD>_VD=<VD>``.
    When ``cache_dtype`` is the literal string ``"default"`` the name ends
    with ``_default.json`` and the GPU is not queried.  Otherwise the name
    ends with ``_<cache_dtype>_<device tag>.json`` where the device tag is
    derived from ``torch.cuda.get_device_name()``.

    Args:
        QH, KVH, QKD, VD: Integers embedded verbatim in the file name.
        cache_dtype: ``"default"`` or a dtype label embedded verbatim in the
            file name (``None`` is formatted as the text ``None``).

    Returns:
        The configuration file name (no directory component).

    Raises:
        ValueError: If ``cache_dtype`` is not ``"default"`` and the device
            name contains neither ``K100_AI`` nor ``BW``.
    """
    prefix = f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}"
    if cache_dtype == "default":
        return f"{prefix}_default.json"

    # Spaces are replaced so that the markers below can be matched against
    # the underscore-separated form of the device name.
    device_name = torch.cuda.get_device_name().replace(" ", "_")

    # (substring looked for in the device name, tag used in the file name).
    # Order matters: "K100_AI" is checked before "BW", as in the original.
    device_tags = (
        ("K100_AI", "K100AI"),
        ("BW", "BW"),
    )
    for marker, tag in device_tags:
        if marker in device_name:
            return f"{prefix}_{cache_dtype}_{tag}.json"

    raise ValueError(f"Unsupported device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
    """Load the decode-attention configuration for the given MLA shape.

    Looks in the ``configs`` directory next to this file for the JSON file
    named by ``get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)``.
    If that file does not exist, a warning is logged and the default
    configuration (``get_mla_config_file_name(16, 1, 576, 512, "default")``)
    is loaded from the same directory instead.

    Results are memoized with ``functools.lru_cache``, so repeated calls with
    the same arguments return the same parsed object; callers should treat it
    as read-only.

    Returns:
        The parsed JSON content of the selected configuration file.  The
        ``Optional`` in the annotation is kept for interface compatibility;
        this function either returns parsed content or raises.

    Raises:
        ValueError: If neither the specific nor the default configuration
            file exists (or, via ``get_mla_config_file_name``, if the device
            is not recognised).
    """
    configs_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs"
    )

    # First look up if an optimized configuration is available in the configs
    # directory
    config_file_path = os.path.join(
        configs_dir, get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
            # If a configuration has been found, return it
            return json.load(f)

    logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)

    # No tuned configuration: fall back to the default one shipped for
    # QH=16, KVH=1, QKD=576, VD=512.
    config_file_path = os.path.join(
        configs_dir, get_mla_config_file_name(16, 1, 576, 512, "default")
    )
    if not os.path.exists(config_file_path):
        raise ValueError(
            f"Default decode attention configuration not found: {config_file_path}. "
            "Please provide a default config for QH=16, KVH=1, QKD=576, VD=512."
        )

    with open(config_file_path) as f:
        logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
        return json.load(f)
class TritonMLABackend(AttentionBackend):
......@@ -682,7 +735,7 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
def _forward_prefill(
self,
q: torch.Tensor,
......@@ -735,12 +788,14 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale,
attn_metadata.num_kv_splits, self.scale, # config,
PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
......@@ -540,7 +540,7 @@ def _decode_softmax_reducev_fwd(
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
......@@ -625,141 +625,21 @@ def decode_attention_fwd_grouped(
# opt
@triton.autotune(
configs=[
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
)
......@@ -1016,6 +896,7 @@ def _decode_v1_stage1_use_tc(
B_Seqlen,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
):
Lk = k_buffer.shape[-1]
......@@ -1033,12 +914,12 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
META['SPLIT_K'],
SPLIT_K,
)
_decode_v1_kernel_stage1_use_tc[grid](
q,
......@@ -1060,6 +941,7 @@ def _decode_v1_stage1_use_tc(
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
......@@ -1119,6 +1001,7 @@ def decode_attention_v1(
b_start_loc,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
......@@ -1134,6 +1017,7 @@ def decode_attention_v1(
b_seq_len,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
)
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
......@@ -1152,121 +1036,21 @@ def decode_attention_v1(
@triton.autotune(
configs=[
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
)
......@@ -1441,6 +1225,7 @@ def _decode_v2_stage1_use_tc(
Req_to_tokens,
# B_req_idx,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
......@@ -1464,11 +1249,12 @@ def _decode_v2_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
META['NUM_KV_SPLITS'],
NUM_KV_SPLITS,
)
_decode_v2_kernel_stage1_use_tc[grid](
......@@ -1496,6 +1282,7 @@ def _decode_v2_stage1_use_tc(
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
......@@ -1612,6 +1399,7 @@ def decode_attention_v2(
# b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
......@@ -1624,11 +1412,12 @@ def decode_attention_v2(
req_to_token,
# b_req_idx,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, _decode_v2_stage1_best_config.kwargs["NUM_KV_SPLITS"])
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
......@@ -1642,6 +1431,7 @@ def decode_attention_fwd(
attn_logits,
num_kv_splits,
sm_scale,
# config,
page_size=1,
logit_cap=0.0,
):
......@@ -1688,25 +1478,64 @@ def decode_attention_fwd(
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# #num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
\ No newline at end of file
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# elif best_config['kernel_kind'] == 'v2_tc':
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment