Commit 71c60bd5 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]add gen_mla_pa_tables.py and optimize triton config configuration

parent 146eb9d3
'''
QH=num_attention_heads KVH=1 QKD=kv_lora_rank + qk_rope_head_dim VD=kv_lora_rank
deepseek v2: QH=16, QKD=512+64=576, VD=512
deepseek v3: QH=128, QKD=512+64=576, VD=512
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
parser.add_argument('--SEQSTART', type=int, required=False, default=100, help='Value for SEQSTART')
parser.add_argument('--SEQEND', type=int, required=False, default=8193, help='Value for SEQEND')
parser.add_argument('--QH', type=int, required=True, help='Value for QH')
parser.add_argument('--KVH', type=int, required=True, help='Value for KVH')
parser.add_argument('--QKD', type=int, required=True, help='Value for QKD')
parser.add_argument('--VD', type=int, required=True, help='Value for VD')
parser.add_argument('--TP', type=int, required=True, help='Value for TP')
return parser.parse_args()
def generate_B():
# 1 4 8 16 ... 128,144 160 ... 256, 288 320 ... 512
B_values = [1, 4] + [i * 8 for i in range(1, 17)]
# B_values = [1, 4] + [i * 8 for i in range(1, 17)] + [(i + 1) * 16 for i in range(8, 16)] + [(i + 1) * 32 for i in range(8, 16)]
return B_values
def generate_SEQ_bs1_32(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_bs40_96(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 200))
SEQ_values_part2 = list(range(5000, seq_end, 300))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_other(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
B_values = generate_B()
SEQ_values_b1_32 = generate_SEQ_bs1_32(seq_start, seq_end)
SEQ_values_b40_96 = generate_SEQ_bs40_96(seq_start, seq_end)
SEQ_values_other = generate_SEQ_other(seq_start, seq_end)
pa_tables = []
for B in B_values:
if B <= 32:
for SEQ in SEQ_values_b1_32:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
elif B <= 96 and B > 32:
for SEQ in SEQ_values_b40_96:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
else:
for SEQ in SEQ_values_other:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
return pa_tables
if __name__ == "__main__":
args = parse_args()
assert args.QH % args.TP == 0, "error QH, QH % TP must be 0"
tuple_sizes = generate_tuples(args.SEQSTART, args.SEQEND, args.QH // args.TP, max(args.KVH // args.TP, 1), args.QKD, args.VD)
for t in tuple_sizes:
print(f"{t},")
This diff is collapsed.
import functools
import json
import torch
import os
from enum import Enum
from typing import Any, Dict, Optional, Tuple
import bisect
from vllm.logger import init_logger
logger = init_logger(__name__)
class KERNLE_KINDS(Enum):
v1 = 0
v1_tc = 1
v1_tc_2 = 2
v1_2stages = 3
v1_2stages_tc = 4
v2 = 5
v2_tc = 6
TOTAL_KIND = 7
class BestConfig():
def __init__(self):
self.batch_size = 0
self.seq_len = 0
self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
self.BLOCK_N = 0
self.BLOCK_SEQ = 0
self.SPLIT_K = 0
self.num_stages = 0
self.num_warps = 0
self.NUM_KV_SPLITS = 0
self.BLOCK_N_2 = 0
self.num_stages_2 = 0
self.num_warps_2 = 0
self.best_us = 0
self.decode_fwd_stage1 = None
self.decode_fwd_stage2 = None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_config_file_name(QH: int, KVH: int, D: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_D={D}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
def get_config_map(attention_configs):
ret_map = {}
for bs in attention_configs.keys():
int_bs = int(bs)
seq_map = {}
seq_configs = attention_configs[bs]
ret_map[int_bs] = seq_map
for seq_len in seq_configs.keys():
int_seq_len = int(seq_len)
kind_config = seq_configs[seq_len]
configs = BestConfig()
configs.batch_size = int_bs
configs.seq_len = int_seq_len
configs.best_us = kind_config['best_us']
seq_map[int_seq_len] = configs
if kind_config['kernel_kind'] == 'v1':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc_2':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc_2
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2_tc
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
return ret_map
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype)
return get_config_map(attention_configs)
def get_closest_key(dic_keys, target_key):
keys = list(dic_keys)
idx = bisect.bisect_left(keys, target_key)
if idx == 0:
return keys[0]
if idx == len(keys):
return keys[-1]
left_key = keys[idx - 1]
right_key = keys[idx]
if target_key - left_key <= right_key - target_key:
return left_key
else:
return right_key
def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
closest_bs_key = get_closest_key(config.keys(), bs_key)
closest_mean_kv_seqlen_key = get_closest_key(config[closest_bs_key].keys(), mean_kv_seqlen_key)
return config[closest_bs_key][closest_mean_kv_seqlen_key]
def get_config(bs_key, mean_kv_seqlen_key, config):
if bs_key in config and mean_kv_seqlen_key in config[bs_key]:
return config[bs_key][mean_kv_seqlen_key]
else:
raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -10,6 +7,7 @@ from itertools import accumulate ...@@ -10,6 +7,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config
try: try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
...@@ -39,65 +37,6 @@ if TYPE_CHECKING: ...@@ -39,65 +37,6 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
def get_config(bs_key, mean_kv_seqlen_key, config):
# 转换参数为字符串以匹配字典的键
bs_key_str = str(bs_key)
mean_kv_seqlen_key_str = str(mean_kv_seqlen_key)
# 检查字典中是否存在对应的配置
if bs_key_str in config and mean_kv_seqlen_key_str in config[bs_key_str]:
return config[bs_key_str][mean_kv_seqlen_key_str]
else:
raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
class TritonMLABackend(AttentionBackend): class TritonMLABackend(AttentionBackend):
@staticmethod @staticmethod
......
...@@ -36,6 +36,7 @@ import triton.language as tl ...@@ -36,6 +36,7 @@ import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
...@@ -897,8 +898,8 @@ def _decode_v1_stage1_use_tc( ...@@ -897,8 +898,8 @@ def _decode_v1_stage1_use_tc(
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits, num_kv_splits,
logit_cap,
best_config, best_config,
logit_cap,
): ):
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
...@@ -916,17 +917,18 @@ def _decode_v1_stage1_use_tc( ...@@ -916,17 +917,18 @@ def _decode_v1_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N'] BLOCK_N = best_config.BLOCK_N
SPLIT_K = num_kv_splits # best_config['SPLIT_K'] ? SPLIT_K = num_kv_splits # best_config.SPLIT_K
num_stages = best_config['num_stages'] num_stages = best_config.num_stages
num_warps = best_config['num_warps'] num_warps = best_config.num_warps
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: ( grid = lambda META: (
batch, batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K, SPLIT_K,
) )
_decode_v1_kernel_stage1_use_tc[grid]( if best_config.decode_fwd_stage1 is None:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q, q,
k_buffer, k_buffer,
sm_scale, sm_scale,
...@@ -955,6 +957,24 @@ def _decode_v1_stage1_use_tc( ...@@ -955,6 +957,24 @@ def _decode_v1_stage1_use_tc(
Lk=Lk, Lk=Lk,
kpack=2, kpack=2,
) )
else:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
)
# return _decode_v1_kernel_stage1_use_tc.best_config # return _decode_v1_kernel_stage1_use_tc.best_config
...@@ -966,21 +986,22 @@ def _decode_v1_stage2_use_tc( ...@@ -966,21 +986,22 @@ def _decode_v1_stage2_use_tc(
#b_req_idx, #b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
page_size,
best_config, best_config,
page_size,
): ):
batch, head_num = b_seq_len.shape[0], logits.shape[0] batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2] kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N'] BLOCK_N = best_config.BLOCK_N_2
num_stages = best_config['num_stages'] num_stages = best_config.num_stages_2
num_warps = best_config['num_warps'] num_warps = best_config.num_warps_2
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv) BLOCK_DMODEL = triton.next_power_of_2(Lv)
_decode_v1_kernel_stage2_use_tc[grid]( if best_config.decode_fwd_stage2 is None:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits, logits,
v_buffer, v_buffer,
o, o,
...@@ -1004,6 +1025,23 @@ def _decode_v1_stage2_use_tc( ...@@ -1004,6 +1025,23 @@ def _decode_v1_stage2_use_tc(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
else:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
# return _decode_v1_kernel_stage2_use_tc.best_config # return _decode_v1_kernel_stage2_use_tc.best_config
...@@ -1059,8 +1097,8 @@ def decode_attention_v1( ...@@ -1059,8 +1097,8 @@ def decode_attention_v1(
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits, num_kv_splits,
best_config,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v1_stage2_use_tc( _decode_v1_stage2_use_tc(
attn_logits, attn_logits,
...@@ -1070,12 +1108,11 @@ def decode_attention_v1( ...@@ -1070,12 +1108,11 @@ def decode_attention_v1(
#b_req_idx, #b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
best_config,
page_size, page_size,
best_config['stage2'],
) )
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
...@@ -1274,9 +1311,9 @@ def _decode_v2_stage1_use_tc( ...@@ -1274,9 +1311,9 @@ def _decode_v2_stage1_use_tc(
logit_cap, logit_cap,
): ):
BLOCK = best_config['BLOCK_N'] BLOCK = best_config.BLOCK_N
num_stages = best_config['num_stages'] num_stages = best_config.num_stages
num_warps = best_config['num_warps'] num_warps = best_config.num_warps
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1304,7 +1341,8 @@ def _decode_v2_stage1_use_tc( ...@@ -1304,7 +1341,8 @@ def _decode_v2_stage1_use_tc(
NUM_KV_SPLITS, NUM_KV_SPLITS,
) )
_decode_v2_kernel_stage1_use_tc[grid]( if best_config.decode_fwd_stage1 is None:
best_config.decode_fwd_stage1 =_decode_v2_kernel_stage1_use_tc[grid](
q, q,
k_buffer, k_buffer,
v_buffer, v_buffer,
...@@ -1339,8 +1377,30 @@ def _decode_v2_stage1_use_tc( ...@@ -1339,8 +1377,30 @@ def _decode_v2_stage1_use_tc(
Lv=Lv, Lv=Lv,
kpack=2, kpack=2,
) )
else:
best_config.decode_fwd_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
)
# return _decode_v2_kernel_stage1_use_tc.best_config # return _decode_v2_kernel_stage1_use_tc.best_config
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
...@@ -1416,8 +1476,8 @@ def _decode_v2_stage2_use_tc( ...@@ -1416,8 +1476,8 @@ def _decode_v2_stage2_use_tc(
num_kv_splits, num_kv_splits,
best_config, best_config,
): ):
num_stages = best_config['num_stages'] num_stages = best_config.num_stages_2
num_warps = best_config['num_warps'] num_warps = best_config.num_warps_2
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1425,8 +1485,9 @@ def _decode_v2_stage2_use_tc( ...@@ -1425,8 +1485,9 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS = num_kv_splits NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num) grid = (batch, head_num, 1)
_decode_v2_kernel_stage2[grid]( if best_config.decode_fwd_stage2 is None:
best_config.decode_fwd_stage2 = _decode_v2_kernel_stage2[grid](
logits, logits,
o, o,
b_seq_len, b_seq_len,
...@@ -1441,6 +1502,17 @@ def _decode_v2_stage2_use_tc( ...@@ -1441,6 +1502,17 @@ def _decode_v2_stage2_use_tc(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
else:
best_config.decode_fwd_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
)
# return _decode_v2_kernel_stage2.best_config # return _decode_v2_kernel_stage2.best_config
...@@ -1485,11 +1557,11 @@ def decode_attention_v2( ...@@ -1485,11 +1557,11 @@ def decode_attention_v2(
b_seq_len, b_seq_len,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config,
page_size, page_size,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config['stage2']) _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config)
def decode_attention_fwd( def decode_attention_fwd(
...@@ -1560,7 +1632,7 @@ def decode_attention_fwd( ...@@ -1560,7 +1632,7 @@ def decode_attention_fwd(
logit_cap, logit_cap,
)''' )'''
if best_config['kernel_kind'] == 'v1_2stages_tc': if best_config.kernel_kind == KERNLE_KINDS.v1_2stages_tc:
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size), (q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16, dtype=torch.float16,
...@@ -1576,11 +1648,11 @@ def decode_attention_fwd( ...@@ -1576,11 +1648,11 @@ def decode_attention_fwd(
attn_logits_v1, attn_logits_v1,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config['best_config'], best_config=best_config,
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
elif best_config['kernel_kind'] == 'v2_tc': elif best_config.kernel_kind == KERNLE_KINDS.v2_tc:
decode_attention_v2( decode_attention_v2(
q, q,
k_buffer, k_buffer,
...@@ -1591,7 +1663,7 @@ def decode_attention_fwd( ...@@ -1591,7 +1663,7 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config['best_config'], best_config=best_config,
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment