Commit 71c60bd5 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]add gen_mla_pa_tables.py and optimize triton config configuration

parent 146eb9d3
'''
Parameters: QH = num_attention_heads, KVH = 1, QKD = kv_lora_rank + qk_rope_head_dim, VD = kv_lora_rank.
deepseek v2: QH=16, QKD=512+64=576, VD=512
deepseek v3: QH=128, QKD=512+64=576, VD=512
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse
def parse_args():
    """Parse the command-line options of the MLA paged-attention table generator.

    Returns:
        argparse.Namespace with integer attributes ``SEQSTART``, ``SEQEND``,
        ``QH``, ``KVH``, ``QKD``, ``VD`` and ``TP``.
    """
    cli = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
    # The sequence-range bounds are optional and fall back to defaults.
    for flag, fallback in (('SEQSTART', 100), ('SEQEND', 8193)):
        cli.add_argument(f'--{flag}', type=int, required=False, default=fallback, help=f'Value for {flag}')
    # The attention shape and tensor-parallel degree must always be supplied.
    for flag in ('QH', 'KVH', 'QKD', 'VD', 'TP'):
        cli.add_argument(f'--{flag}', type=int, required=True, help=f'Value for {flag}')
    return cli.parse_args()
def generate_B():
    """Return the batch sizes to tabulate: 1, 4, then every multiple of 8 up to 128."""
    # NOTE: a denser extension (steps of 16 up to 256, then steps of 32 up to
    # 512) was sketched for this grid but is currently disabled.
    return [1, 4, *range(8, 129, 8)]
def generate_SEQ_bs1_32(seq_start, seq_end):
    """Return the sequence lengths sampled for batch sizes up to 32.

    The grid is 1, then ``seq_start`` upwards in steps of 300 below 5000,
    then 5000 upwards in steps of 500.

    Args:
        seq_start: First sequence length of the fine-grained range.
        seq_end: Exclusive upper bound for every generated length.

    Returns:
        List of sequence lengths, starting with 1.
    """
    # Clamp the fine range to seq_end so a bound below 5000 is honoured
    # instead of silently producing lengths up to 4999.
    fine = range(seq_start, min(5000, seq_end), 300)
    coarse = range(5000, seq_end, 500)
    # 1 is always sampled first; filter it from the fine range so that
    # seq_start == 1 does not produce a duplicate entry.
    return [1, *(s for s in fine if s != 1), *coarse]
def generate_SEQ_bs40_96(seq_start, seq_end):
    """Return the sequence lengths sampled for batch sizes from 40 to 96.

    The grid is 1, then ``seq_start`` upwards in steps of 200 below 5000,
    then 5000 upwards in steps of 300.

    Args:
        seq_start: First sequence length of the fine-grained range.
        seq_end: Exclusive upper bound for every generated length.

    Returns:
        List of sequence lengths, starting with 1.
    """
    # Clamp the fine range to seq_end so a bound below 5000 is honoured
    # instead of silently producing lengths up to 4999.
    fine = range(seq_start, min(5000, seq_end), 200)
    coarse = range(5000, seq_end, 300)
    # 1 is always sampled first; filter it from the fine range so that
    # seq_start == 1 does not produce a duplicate entry.
    return [1, *(s for s in fine if s != 1), *coarse]
def generate_SEQ_other(seq_start, seq_end):
    """Return the sequence lengths sampled for batch sizes above 96.

    The grid is 1, then ``seq_start`` upwards in steps of 300 below 5000,
    then 5000 upwards in steps of 500.

    Args:
        seq_start: First sequence length of the fine-grained range.
        seq_end: Exclusive upper bound for every generated length.

    Returns:
        List of sequence lengths, starting with 1.
    """
    # Clamp the fine range to seq_end so a bound below 5000 is honoured
    # instead of silently producing lengths up to 4999.
    fine = range(seq_start, min(5000, seq_end), 300)
    coarse = range(5000, seq_end, 500)
    # 1 is always sampled first; filter it from the fine range so that
    # seq_start == 1 does not produce a duplicate entry.
    return [1, *(s for s in fine if s != 1), *coarse]
def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
    """Build the full list of (B, SEQ, QH, KVH, QKD, VD) table rows.

    Every batch size from ``generate_B`` is paired with the sequence-length
    grid that belongs to its size bucket: up to 32, 33 to 96, or larger.

    Args:
        seq_start: Forwarded to the sequence-length generators.
        seq_end: Forwarded to the sequence-length generators.
        QH, KVH, QKD, VD: Values copied unchanged into every row.

    Returns:
        List of 6-tuples ordered by batch size, then by sequence length.
    """
    batch_sizes = generate_B()
    seqs_small = generate_SEQ_bs1_32(seq_start, seq_end)
    seqs_medium = generate_SEQ_bs40_96(seq_start, seq_end)
    seqs_large = generate_SEQ_other(seq_start, seq_end)

    def seqs_for(batch):
        # Pick the sequence grid of the bucket this batch size falls into.
        if batch <= 32:
            return seqs_small
        if batch <= 96:
            return seqs_medium
        return seqs_large

    return [
        (batch, seq, QH, KVH, QKD, VD)
        for batch in batch_sizes
        for seq in seqs_for(batch)
    ]
if __name__ == "__main__":
    # Script entry point: print one "(B, SEQ, QH, KVH, QKD, VD)," line per row.
    args = parse_args()
    # Validate with explicit raises: `assert` is stripped under `python -O`,
    # and a non-positive TP would otherwise fail with ZeroDivisionError or
    # yield negative per-rank head counts.
    if args.TP <= 0:
        raise ValueError("error TP, TP must be a positive integer")
    if args.QH % args.TP != 0:
        raise ValueError("error QH, QH % TP must be 0")
    # Heads are split across tensor-parallel ranks; KV heads never drop below 1.
    heads_per_rank = args.QH // args.TP
    kv_heads_per_rank = max(args.KVH // args.TP, 1)
    tuple_sizes = generate_tuples(args.SEQSTART, args.SEQEND, heads_per_rank, kv_heads_per_rank, args.QKD, args.VD)
    for t in tuple_sizes:
        print(f"{t},")
This diff is collapsed.
import functools
import json
import torch
import os
from enum import Enum
from typing import Any, Dict, Optional, Tuple
import bisect
from vllm.logger import init_logger
logger = init_logger(__name__)
class KERNLE_KINDS(Enum):
    """Identifiers for the decode-attention kernel variants a tuned config can select.

    The member names ``v1`` through ``v2_tc`` are the same strings that
    ``get_config_map`` compares against the ``'kernel_kind'`` field of each
    configuration entry.
    """
    v1 = 0
    v1_tc = 1
    v1_tc_2 = 2
    v1_2stages = 3
    v1_2stages_tc = 4
    v2 = 5
    v2_tc = 6
    # Placeholder that BestConfig starts with before a real kind is assigned.
    TOTAL_KIND = 7
class BestConfig:
    """Mutable record holding the tuned launch parameters for one (batch size, sequence length) entry.

    A fresh instance holds zeros, ``KERNLE_KINDS.TOTAL_KIND`` and ``None``;
    ``get_config_map`` fills in the fields relevant to the entry's kernel kind.
    """

    def __init__(self):
        # Lookup key of this entry.
        self.batch_size = self.seq_len = 0
        # TOTAL_KIND acts as the "not assigned yet" value.
        self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
        # First-stage (or single-stage) launch parameters.
        self.BLOCK_N = self.BLOCK_SEQ = self.SPLIT_K = 0
        self.num_stages = self.num_warps = 0
        self.NUM_KV_SPLITS = 0
        # Second-stage launch parameters.
        self.BLOCK_N_2 = self.num_stages_2 = self.num_warps_2 = 0
        # Copied from the 'best_us' field of the entry
        # (presumably the best measured time in microseconds; TODO confirm).
        self.best_us = 0
        # Slots for per-stage kernel handles; always None after construction.
        self.decode_fwd_stage1 = None
        self.decode_fwd_stage2 = None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
    """Return the JSON file name of the MLA decode-attention config for the given shape.

    Args:
        QH, KVH, QKD, VD: Shape values embedded in the file name.
        cache_dtype: ``"default"`` selects the device-independent default
            file; any other value is embedded together with a device tag.

    Returns:
        The bare file name, without a directory.

    Raises:
        ValueError: If the CUDA device name contains neither ``K100_AI`` nor ``BW``.
    """
    stem = f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}"
    if cache_dtype == "default":
        # The default file is not device specific, so the device is not queried.
        return f"{stem}_default.json"

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    # Checked in order: marker searched in the device name -> tag used in the file name.
    for marker, tag in (("K100_AI", "K100AI"), ("BW", "BW")):
        if marker in device_name:
            return f"{stem}_{cache_dtype}_{tag}.json"
    raise ValueError(f"Unsurpport device name: {device_name}")
def get_config_file_name(QH: int, KVH: int, D: int, cache_dtype: Optional[str]) -> str:
    """Return the JSON file name of the decode-attention config keyed by a single head dim.

    Args:
        QH, KVH, D: Shape values embedded in the file name.
        cache_dtype: ``"default"`` selects the device-independent default
            file; any other value is embedded together with a device tag.

    Returns:
        The bare file name, without a directory.

    Raises:
        ValueError: If the CUDA device name contains neither ``K100_AI`` nor ``BW``.
    """
    stem = f"QH={QH}_KVH={KVH}_D={D}"
    if cache_dtype == "default":
        # The default file is not device specific, so the device is not queried.
        return f"{stem}_default.json"

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    # Checked in order: marker searched in the device name -> tag used in the file name.
    for marker, tag in (("K100_AI", "K100AI"), ("BW", "BW")):
        if marker in device_name:
            return f"{stem}_{cache_dtype}_{tag}.json"
    raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
    """Load the tuned MLA decode-attention configuration as parsed JSON.

    The shape- and device-specific file in the ``configs`` directory next to
    this module is tried first. If it is absent, the default file for
    QH=16, KVH=1, QKD=576, VD=512 is loaded instead.

    Args:
        QH, KVH, QKD, VD: Shape values used to derive the config file name.
        cache_dtype: Forwarded to ``get_mla_config_file_name``.

    Returns:
        The parsed JSON content of whichever file was found.

    Raises:
        ValueError: If neither the specific nor the default file exists
            (``get_mla_config_file_name`` may also raise it for an
            unsupported device).
    """
    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")

    # First look up if an optimized configuration is available in the configs
    # directory.
    config_file_path = os.path.join(
        config_dir, get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
    )
    if os.path.exists(config_file_path):
        with open(config_file_path, encoding="utf-8") as f:
            logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
            return json.load(f)

    logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)

    # Fall back to the shipped default configuration.
    default_file_path = os.path.join(
        config_dir, get_mla_config_file_name(16, 1, 576, 512, "default")
    )
    if not os.path.exists(default_file_path):
        # Name the missing file so the failure is actionable.
        raise ValueError(
            f"Default decode attention configuration {default_file_path} is missing; "
            "please provide a default config for QH=16 KVH=1 QKD=576 VD=512"
        )
    with open(default_file_path, encoding="utf-8") as f:
        logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", default_file_path)
        return json.load(f)
def get_config_map(attention_configs):
    """Convert the parsed JSON config into ``{batch_size: {seq_len: BestConfig}}``.

    Args:
        attention_configs: Mapping of batch-size key -> sequence-length key ->
            entry. Each entry provides ``'best_us'``, ``'kernel_kind'`` and,
            for the recognised kinds, ``'best_config'``.

    Returns:
        Nested dict keyed by ``int`` batch size and ``int`` sequence length
        whose leaves are populated ``BestConfig`` objects. An entry whose
        ``'kernel_kind'`` is not recognised keeps the ``BestConfig`` defaults.
    """
    # Kinds grouped by the layout of their 'best_config' payload.
    flat_kinds = ('v1', 'v1_tc', 'v1_tc_2')            # parameters stored directly
    two_stage_kinds = ('v1_2stages', 'v1_2stages_tc')  # 'stage1' / 'stage2' sub-dicts
    split_kv_kinds = ('v2', 'v2_tc')                   # staged, with a KV-split setting

    ret_map = {}
    for bs, seq_configs in attention_configs.items():
        int_bs = int(bs)
        seq_map = ret_map[int_bs] = {}
        for seq_len, kind_config in seq_configs.items():
            int_seq_len = int(seq_len)
            entry = BestConfig()
            entry.batch_size = int_bs
            entry.seq_len = int_seq_len
            entry.best_us = kind_config['best_us']
            seq_map[int_seq_len] = entry

            kind_name = kind_config['kernel_kind']
            if kind_name in flat_kinds:
                params = kind_config['best_config']
                # Enum member names are identical to the JSON kind strings.
                entry.kernel_kind = KERNLE_KINDS[kind_name]
                entry.BLOCK_N = params['BLOCK_N']
                entry.num_stages = params['num_stages']
                entry.num_warps = params['num_warps']
            elif kind_name in two_stage_kinds:
                params = kind_config['best_config']
                first, second = params['stage1'], params['stage2']
                entry.kernel_kind = KERNLE_KINDS[kind_name]
                # SPLIT_K is deliberately not read from the config here.
                entry.BLOCK_N = first['BLOCK_N']
                entry.num_stages = first['num_stages']
                entry.num_warps = first['num_warps']
                entry.BLOCK_N_2 = second['BLOCK_N']
                entry.num_stages_2 = second['num_stages']
                entry.num_warps_2 = second['num_warps']
            elif kind_name in split_kv_kinds:
                params = kind_config['best_config']
                first, second = params['stage1'], params['stage2']
                entry.kernel_kind = KERNLE_KINDS[kind_name]
                # Stage 1 carries either BLOCK_SEQ or NUM_KV_SPLITS.
                if 'BLOCK_SEQ' in first:
                    entry.BLOCK_SEQ = first['BLOCK_SEQ']
                else:
                    entry.NUM_KV_SPLITS = first['NUM_KV_SPLITS']
                entry.BLOCK_N = first['BLOCK_N']
                entry.num_stages = first['num_stages']
                entry.num_warps = first['num_warps']
                entry.num_stages_2 = second['num_stages']
                entry.num_warps_2 = second['num_warps']
    return ret_map
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
    """Return the ``{batch_size: {seq_len: BestConfig}}`` map for the given shape.

    Results are memoised per argument tuple, so repeated calls with the same
    arguments return the same map object without re-reading the JSON file.
    """
    return get_config_map(
        get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype)
    )
def get_closest_key(dic_keys, target_key):
    """Return the key in *dic_keys* that is numerically nearest to *target_key*.

    The keys may arrive in any order; dict key views follow insertion order,
    which need not be sorted. When two keys are equally close, the smaller
    one is returned.

    Args:
        dic_keys: Iterable of numeric keys (for example ``some_dict.keys()``).
        target_key: Numeric value to match.

    Returns:
        The nearest key.

    Raises:
        ValueError: If *dic_keys* is empty.
    """
    keys = list(dic_keys)
    if not keys:
        raise ValueError("get_closest_key requires at least one key")
    # Rank by distance first and by key second, so ties resolve to the
    # smaller key regardless of input order.
    return min(keys, key=lambda key: (abs(key - target_key), key))
def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
    """Return the config entry whose batch-size and sequence-length keys are nearest to the request.

    The batch-size key is matched first; the sequence-length key is then
    matched within that batch size's sub-map.
    """
    nearest_bs = get_closest_key(config.keys(), bs_key)
    per_seq = config[nearest_bs]
    nearest_seq = get_closest_key(per_seq.keys(), mean_kv_seqlen_key)
    return config[nearest_bs][nearest_seq]
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Return the config entry stored under exactly the given keys.

    Raises:
        ValueError: If either key is absent from the nested mapping.
    """
    # Guard clause: reject a miss at either nesting level before indexing.
    if bs_key not in config or mean_kv_seqlen_key not in config[bs_key]:
        raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
    return config[bs_key][mean_kv_seqlen_key]
# SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -10,6 +7,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
......@@ -39,65 +37,6 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata)
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Return the config entry stored under the string forms of the given keys.

    Raises:
        ValueError: If either stringified key is absent from the nested mapping.
    """
    # Convert the arguments to strings so they match the dictionary's keys.
    bs_name, seq_name = str(bs_key), str(mean_kv_seqlen_key)
    # Check that the dictionary holds a matching configuration.
    if bs_name not in config or seq_name not in config[bs_name]:
        raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
    return config[bs_name][seq_name]
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
    """Return the JSON file name of the MLA decode-attention config for the given shape.

    ``cache_dtype == "default"`` yields the device-independent default file
    name; otherwise the name embeds ``cache_dtype`` and a device tag.

    Raises:
        ValueError: If the CUDA device name contains neither ``K100_AI`` nor ``BW``.
    """
    stem = f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}"
    if cache_dtype == "default":
        # The default file is not device specific, so the device is not queried.
        return f"{stem}_default.json"

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    # Checked in order: marker searched in the device name -> tag used in the file name.
    for marker, tag in (("K100_AI", "K100AI"), ("BW", "BW")):
        if marker in device_name:
            return f"{stem}_{cache_dtype}_{tag}.json"
    raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
    """Load the tuned MLA decode-attention configuration as parsed JSON.

    The shape- and device-specific file in the ``configs`` directory next to
    this module is tried first. If it is absent, the default file for
    QH=16, KVH=1, QKD=576, VD=512 is loaded instead. Results are memoised
    per argument tuple.

    Returns:
        The parsed JSON content of whichever file was found.

    Raises:
        ValueError: If neither the specific nor the default file exists
            (``get_mla_config_file_name`` may also raise it for an
            unsupported device).
    """
    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")

    # First look up if an optimized configuration is available in the configs
    # directory.
    config_file_path = os.path.join(
        config_dir, get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
    )
    if os.path.exists(config_file_path):
        with open(config_file_path, encoding="utf-8") as f:
            logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
            return json.load(f)

    logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)

    # Fall back to the shipped default configuration.
    default_file_path = os.path.join(
        config_dir, get_mla_config_file_name(16, 1, 576, 512, "default")
    )
    if not os.path.exists(default_file_path):
        # Name the missing file so the failure is actionable.
        raise ValueError(
            f"Default decode attention configuration {default_file_path} is missing; "
            "please provide a default config for QH=16 KVH=1 QKD=576 VD=512"
        )
    with open(default_file_path, encoding="utf-8") as f:
        logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", default_file_path)
        return json.load(f)
class TritonMLABackend(AttentionBackend):
@staticmethod
......
......@@ -36,6 +36,7 @@ import triton.language as tl
from vllm.platforms import current_platform
from vllm import envs
from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
......@@ -897,8 +898,8 @@ def _decode_v1_stage1_use_tc(
sm_scale,
page_size,
num_kv_splits,
logit_cap,
best_config,
logit_cap,
):
Lk = k_buffer.shape[-1]
......@@ -916,17 +917,18 @@ def _decode_v1_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N']
SPLIT_K = num_kv_splits # best_config['SPLIT_K'] ?
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_N = best_config.BLOCK_N
SPLIT_K = num_kv_splits # best_config.SPLIT_K
num_stages = best_config.num_stages
num_warps = best_config.num_warps
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K,
)
_decode_v1_kernel_stage1_use_tc[grid](
if best_config.decode_fwd_stage1 is None:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
......@@ -955,6 +957,24 @@ def _decode_v1_stage1_use_tc(
Lk=Lk,
kpack=2,
)
else:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
)
# return _decode_v1_kernel_stage1_use_tc.best_config
......@@ -966,21 +986,22 @@ def _decode_v1_stage2_use_tc(
#b_req_idx,
b_start_loc,
b_seq_len,
page_size,
best_config,
page_size,
):
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_N = best_config.BLOCK_N_2
num_stages = best_config.num_stages_2
num_warps = best_config.num_warps_2
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_decode_v1_kernel_stage2_use_tc[grid](
if best_config.decode_fwd_stage2 is None:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
......@@ -1004,6 +1025,23 @@ def _decode_v1_stage2_use_tc(
num_warps=num_warps,
num_stages=num_stages,
)
else:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
# return _decode_v1_kernel_stage2_use_tc.best_config
......@@ -1059,8 +1097,8 @@ def decode_attention_v1(
sm_scale,
page_size,
num_kv_splits,
best_config,
logit_cap,
best_config['stage1'],
)
_decode_v1_stage2_use_tc(
attn_logits,
......@@ -1070,12 +1108,11 @@ def decode_attention_v1(
#b_req_idx,
b_start_loc,
b_seq_len,
best_config,
page_size,
best_config['stage2'],
)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
......@@ -1274,9 +1311,9 @@ def _decode_v2_stage1_use_tc(
logit_cap,
):
BLOCK = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK = best_config.BLOCK_N
num_stages = best_config.num_stages
num_warps = best_config.num_warps
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
......@@ -1304,7 +1341,8 @@ def _decode_v2_stage1_use_tc(
NUM_KV_SPLITS,
)
_decode_v2_kernel_stage1_use_tc[grid](
if best_config.decode_fwd_stage1 is None:
best_config.decode_fwd_stage1 =_decode_v2_kernel_stage1_use_tc[grid](
q,
k_buffer,
v_buffer,
......@@ -1339,8 +1377,30 @@ def _decode_v2_stage1_use_tc(
Lv=Lv,
kpack=2,
)
else:
best_config.decode_fwd_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
)
# return _decode_v2_kernel_stage1_use_tc.best_config
# @triton.autotune(
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
......@@ -1416,8 +1476,8 @@ def _decode_v2_stage2_use_tc(
num_kv_splits,
best_config,
):
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
num_stages = best_config.num_stages_2
num_warps = best_config.num_warps_2
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
......@@ -1425,8 +1485,9 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num)
_decode_v2_kernel_stage2[grid](
grid = (batch, head_num, 1)
if best_config.decode_fwd_stage2 is None:
best_config.decode_fwd_stage2 = _decode_v2_kernel_stage2[grid](
logits,
o,
b_seq_len,
......@@ -1441,6 +1502,17 @@ def _decode_v2_stage2_use_tc(
num_warps=num_warps,
num_stages=num_stages,
)
else:
best_config.decode_fwd_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
)
# return _decode_v2_kernel_stage2.best_config
......@@ -1485,11 +1557,11 @@ def decode_attention_v2(
b_seq_len,
num_kv_splits,
sm_scale,
best_config,
page_size,
logit_cap,
best_config['stage1'],
)
_decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config['stage2'])
_decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config)
def decode_attention_fwd(
......@@ -1560,7 +1632,7 @@ def decode_attention_fwd(
logit_cap,
)'''
if best_config['kernel_kind'] == 'v1_2stages_tc':
if best_config.kernel_kind == KERNLE_KINDS.v1_2stages_tc:
attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16,
......@@ -1576,11 +1648,11 @@ def decode_attention_fwd(
attn_logits_v1,
num_kv_splits,
sm_scale,
best_config=best_config['best_config'],
best_config=best_config,
page_size=page_size,
logit_cap=logit_cap,
)
elif best_config['kernel_kind'] == 'v2_tc':
elif best_config.kernel_kind == KERNLE_KINDS.v2_tc:
decode_attention_v2(
q,
k_buffer,
......@@ -1591,7 +1663,7 @@ def decode_attention_fwd(
attn_logits,
num_kv_splits,
sm_scale,
best_config=best_config['best_config'],
best_config=best_config,
page_size=page_size,
logit_cap=logit_cap,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment