Commit 52121d00 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton mla

parent 40083064
...@@ -488,10 +488,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -488,10 +488,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1.' + sha[:7] version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1' version = 'das.opt1'
# dtk version # dtk version
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -16,6 +19,8 @@ except ImportError: ...@@ -16,6 +19,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
...@@ -34,6 +39,54 @@ if TYPE_CHECKING: ...@@ -34,6 +39,54 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
class TritonMLABackend(AttentionBackend): class TritonMLABackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -736,11 +789,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -736,11 +789,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits, decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, attn_metadata.num_kv_splits, self.scale, # config,
PAGE_SIZE) PAGE_SIZE)
return self._v_up_proj_and_o_proj(o) return self._v_up_proj_and_o_proj(o)
...@@ -540,7 +540,7 @@ def _decode_softmax_reducev_fwd( ...@@ -540,7 +540,7 @@ def _decode_softmax_reducev_fwd(
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = { extra_kargs = {
"waves_per_eu": 4, "waves_per_eu": 0,
"matrix_instr_nonkdim": 16, "matrix_instr_nonkdim": 16,
"kpack": 2 "kpack": 2
} }
...@@ -625,141 +625,21 @@ def decode_attention_fwd_grouped( ...@@ -625,141 +625,21 @@ def decode_attention_fwd_grouped(
# opt # opt
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 1, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 2, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 4, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 8, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 16, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 32, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 64, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 128, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"SPLIT_K": 256, "BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
], ],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"] key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
) )
...@@ -1016,6 +896,7 @@ def _decode_v1_stage1_use_tc( ...@@ -1016,6 +896,7 @@ def _decode_v1_stage1_use_tc(
B_Seqlen, B_Seqlen,
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits,
logit_cap, logit_cap,
): ):
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
...@@ -1033,12 +914,12 @@ def _decode_v1_stage1_use_tc( ...@@ -1033,12 +914,12 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1] # batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: ( grid = lambda META: (
batch, batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
META['SPLIT_K'], SPLIT_K,
) )
_decode_v1_kernel_stage1_use_tc[grid]( _decode_v1_kernel_stage1_use_tc[grid](
q, q,
...@@ -1060,6 +941,7 @@ def _decode_v1_stage1_use_tc( ...@@ -1060,6 +941,7 @@ def _decode_v1_stage1_use_tc(
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
Lk=Lk, Lk=Lk,
...@@ -1119,6 +1001,7 @@ def decode_attention_v1( ...@@ -1119,6 +1001,7 @@ def decode_attention_v1(
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
logit_cap=0.0, logit_cap=0.0,
...@@ -1134,6 +1017,7 @@ def decode_attention_v1( ...@@ -1134,6 +1017,7 @@ def decode_attention_v1(
b_seq_len, b_seq_len,
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits,
logit_cap, logit_cap,
) )
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc( _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
...@@ -1152,121 +1036,21 @@ def decode_attention_v1( ...@@ -1152,121 +1036,21 @@ def decode_attention_v1(
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=8, num_stages=1), triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=8, num_stages=1), triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=8, num_stages=1), triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=8, num_stages=1), triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=2, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=4, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=8, num_stages=1), triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
], ],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"] key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
) )
...@@ -1441,6 +1225,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1441,6 +1225,7 @@ def _decode_v2_stage1_use_tc(
Req_to_tokens, Req_to_tokens,
# B_req_idx, # B_req_idx,
B_Seqlen, B_Seqlen,
num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
logit_cap, logit_cap,
...@@ -1464,11 +1249,12 @@ def _decode_v2_stage1_use_tc( ...@@ -1464,11 +1249,12 @@ def _decode_v2_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16 BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = lambda META: ( grid = lambda META: (
batch, batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
META['NUM_KV_SPLITS'], NUM_KV_SPLITS,
) )
_decode_v2_kernel_stage1_use_tc[grid]( _decode_v2_kernel_stage1_use_tc[grid](
...@@ -1496,6 +1282,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1496,6 +1282,7 @@ def _decode_v2_stage1_use_tc(
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
Lk=Lk, Lk=Lk,
...@@ -1612,6 +1399,7 @@ def decode_attention_v2( ...@@ -1612,6 +1399,7 @@ def decode_attention_v2(
# b_req_idx, # b_req_idx,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
logit_cap=0.0, logit_cap=0.0,
...@@ -1624,11 +1412,12 @@ def decode_attention_v2( ...@@ -1624,11 +1412,12 @@ def decode_attention_v2(
req_to_token, req_to_token,
# b_req_idx, # b_req_idx,
b_seq_len, b_seq_len,
num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
logit_cap, logit_cap,
) )
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, _decode_v2_stage1_best_config.kwargs["NUM_KV_SPLITS"]) _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
...@@ -1642,6 +1431,7 @@ def decode_attention_fwd( ...@@ -1642,6 +1431,7 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
# config,
page_size=1, page_size=1,
logit_cap=0.0, logit_cap=0.0,
): ):
...@@ -1688,6 +1478,7 @@ def decode_attention_fwd( ...@@ -1688,6 +1478,7 @@ def decode_attention_fwd(
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
num_kv_splits,
sm_scale, sm_scale,
page_size, page_size,
logit_cap, logit_cap,
...@@ -1705,8 +1496,46 @@ def decode_attention_fwd( ...@@ -1705,8 +1496,46 @@ def decode_attention_fwd(
# b_start_loc, # b_start_loc,
# b_seq_len, # b_seq_len,
# attn_logits_v1, # attn_logits_v1,
# #num_kv_splits, # sub # num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# elif best_config['kernel_kind'] == 'v2_tc':
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale, # sm_scale,
# config,
# page_size, # page_size,
# logit_cap, # logit_cap,
# ) # )
# else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment