Commit 13b1dcfe authored by zhuwenwen's avatar zhuwenwen
Browse files

add mla tuning script and configs

parent b95d1275
'''
QH=num_attention_heads KVH=1 QKD=kv_lora_rank + qk_rope_head_dim VD=kv_lora_rank
deepseek v2: QH=16, QKD=512+64=576, VD=512
deepseek v3: QH=128, QKD=512+64=576, VD=512
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
parser.add_argument('--SEQSTART', type=int, required=False, default=100, help='Value for SEQSTART')
parser.add_argument('--SEQEND', type=int, required=False, default=8193, help='Value for SEQEND')
parser.add_argument('--QH', type=int, required=True, help='Value for QH')
parser.add_argument('--KVH', type=int, required=True, help='Value for KVH')
parser.add_argument('--QKD', type=int, required=True, help='Value for QKD')
parser.add_argument('--VD', type=int, required=True, help='Value for VD')
parser.add_argument('--TP', type=int, required=True, help='Value for TP')
return parser.parse_args()
def generate_B():
# 1 4 8 16 ... 128,144 160 ... 256, 288 320 ... 512
B_values = [1, 4] + [i * 8 for i in range(1, 17)]
# B_values = [1, 4] + [i * 8 for i in range(1, 17)] + [(i + 1) * 16 for i in range(8, 16)] + [(i + 1) * 32 for i in range(8, 16)]
return B_values
def generate_SEQ_bs1_32(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_bs40_96(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 200))
SEQ_values_part2 = list(range(5000, seq_end, 300))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_other(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
B_values = generate_B()
SEQ_values_b1_32 = generate_SEQ_bs1_32(seq_start, seq_end)
SEQ_values_b40_96 = generate_SEQ_bs40_96(seq_start, seq_end)
SEQ_values_other = generate_SEQ_other(seq_start, seq_end)
pa_tables = []
for B in B_values:
if B <= 32:
for SEQ in SEQ_values_b1_32:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
elif B <= 96 and B > 32:
for SEQ in SEQ_values_b40_96:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
else:
for SEQ in SEQ_values_other:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
return pa_tables
if __name__ == "__main__":
args = parse_args()
assert args.QH % args.TP == 0, "error QH, QH % TP must be 0"
tuple_sizes = generate_tuples(args.SEQSTART, args.SEQEND, args.QH // args.TP, max(args.KVH // args.TP, 1), args.QKD, args.VD)
for t in tuple_sizes:
print(f"{t},")
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
import os
import json
import pytest
import torch
import triton
from triton_decode_attention import decode_attentionv1_fwd, decode_attentionv2_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [1])
# @pytest.mark.parametrize("L", [100])
# @pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000])
@pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000,8500,9000,9500,10000,10500,11000,11500,12000,12500,13000,13500,14000,14500,15000,15500,16000,16500,17000,17500,18000,18500,19000,19500,20000,20500,21000,21500,22000,22500,23000,23500,24000,24500,25000,25500,26000,26500,27000,27500,28000,28500,29000,29500,30000,30500,31000,31500,32000,32500])
@pytest.mark.parametrize("H_Q", [4, 8, 16])
@pytest.mark.parametrize("H_KV", [1])
@pytest.mark.parametrize("D_QK", [576])
@pytest.mark.parametrize("D_V", [512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 4
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) #这里为向上取整,65,(1027+16-1)//16
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), #shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
b_req_idx = torch.arange(B, device="cuda").to(torch.int32)
# Call the original implementation.
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
configs = {
"v2_tc": {"stage1": {}, "stage2": {}},
"v1_2stages_tc": {"stage1": {}, "stage2": {}},
}
ms = {
"v1_2stages_tc": 10000.0,
"v2_tc": 10000.0,
}
final_best_config = {
"kernel_kind": "",
"best_config": {},
"best_us": 0.0,
}
v2_tc_stage1_best_config, v2_tc_stage2_best_config = decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
quantiles = [0.5, 0.2, 0.8]
v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v2_tc_stage1_best_config.kwargs.items():
configs["v2_tc"]["stage1"][key] = value
configs["v2_tc"]["stage1"]["num_stages"] = v2_tc_stage1_best_config.num_stages
configs["v2_tc"]["stage1"]["num_warps"] = v2_tc_stage1_best_config.num_warps
for key, value in v2_tc_stage2_best_config.kwargs.items():
configs["v2_tc"]["stage2"][key] = value
configs["v2_tc"]["stage2"]["num_stages"] = v2_tc_stage2_best_config.num_stages
configs["v2_tc"]["stage2"]["num_warps"] = v2_tc_stage2_best_config.num_warps
ms["v2_tc"] = v2_tc_ms
print(f"v2_tc best configs is {configs['v2_tc']}")
print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
o2 = torch.zeros_like(o)
v1_tc_stage1_best_config, v1_tc_stage2_best_config = decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o2,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v1_tc_stage1_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage1"][key] = value
configs["v1_2stages_tc"]["stage1"]["num_stages"] = v1_tc_stage1_best_config.num_stages
configs["v1_2stages_tc"]["stage1"]["num_warps"] = v1_tc_stage1_best_config.num_warps
configs["v1_2stages_tc"]["stage1"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
for key, value in v1_tc_stage2_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage2"][key] = value
configs["v1_2stages_tc"]["stage2"]["num_stages"] = v1_tc_stage2_best_config.num_stages
configs["v1_2stages_tc"]["stage2"]["num_warps"] = v1_tc_stage2_best_config.num_warps
configs["v1_2stages_tc"]["stage2"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
ms["v1_2stages_tc"] = v1_tc_ms
min_key, min_ms = min(ms.items(), key=lambda x: x[1])
final_best_config["kernel_kind"] = min_key
final_best_config["best_config"] = configs[min_key]
final_best_config["best_us"] = min_ms * 1000
print(f"v1_2stages_tc best configs is {configs['v1_2stages_tc']}")
print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
print(f"Tuned_decode_attention choose {min_key} kernel, min cost {min_ms} ms, best config of {min_key} kernel is {configs[min_key]}")
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
#**************save config**************#
batch = b_req_idx.shape[0]
mean_seq_len = int((b_seq_len.sum() / max(1, batch)).item())
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_K100AI.json"
elif "BW" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
if os.path.exists(file_name):
with open(file_name, 'r') as file:
config_info = json.load(file)
else:
config_info = {}
# 如果 config_info 中没有当前的 batch,初始化它为一个空字典
# if f"{batch}" not in config_info:
# config_info[f"{batch}"] = {}
# 把新的 mean_seq_len 配置加入到当前 batch 中
# config_info[f"{batch}"][f"{mean_seq_len}"] = final_best_config
config_info[f"{mean_seq_len}"] = final_best_config
# 保存最佳配置
with open(file_name, 'w') as file:
json.dump(config_info, file, indent=1)
#**************save config**************#
This diff is collapsed.
...@@ -9,14 +9,11 @@ from vllm.logger import init_logger ...@@ -9,14 +9,11 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
class KERNLE_KINDS(Enum): class KERNLE_KINDS(Enum):
v1 = 0 v1_2stages = 0
v1_tc = 1 v1_2stages_tc = 1
v1_tc_2 = 2 v2 = 2
v1_2stages = 3 v2_tc = 3
v1_2stages_tc = 4 TOTAL_KIND = 4
v2 = 5
v2_tc = 6
TOTAL_KIND = 7
class BestConfig(): class BestConfig():
def __init__(self): def __init__(self):
...@@ -24,8 +21,9 @@ class BestConfig(): ...@@ -24,8 +21,9 @@ class BestConfig():
self.seq_len = 0 self.seq_len = 0
self.kernel_kind = KERNLE_KINDS.TOTAL_KIND self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
self.BLOCK_N = 0 self.BLOCK_N = 0
self.BLOCK_SEQ = 0 self.BLOCK_DIM = 0
self.SPLIT_K = 0 # self.BLOCK_SEQ = 0
# self.SPLIT_K = 0
self.num_stages = 0 self.num_stages = 0
self.num_warps = 0 self.num_warps = 0
self.NUM_KV_SPLITS = 0 self.NUM_KV_SPLITS = 0
...@@ -49,19 +47,6 @@ def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: ...@@ -49,19 +47,6 @@ def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype:
raise ValueError(f"Unsurpport device name: {device_name}") raise ValueError(f"Unsurpport device name: {device_name}")
def get_config_file_name(QH: int, KVH: int, D: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_D={D}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]: def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
...@@ -73,7 +58,7 @@ def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_d ...@@ -73,7 +58,7 @@ def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_d
) )
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path) # logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it # If a configuration has been found, return it
return json.load(f) return json.load(f)
else: else:
...@@ -107,29 +92,11 @@ def get_config_map(attention_configs): ...@@ -107,29 +92,11 @@ def get_config_map(attention_configs):
int_seq_len = int(seq_len) int_seq_len = int(seq_len)
kind_config = seq_configs[seq_len] kind_config = seq_configs[seq_len]
configs = BestConfig() configs = BestConfig()
configs.batch_size = int_bs # configs.batch_size = int_bs
configs.seq_len = int_seq_len # configs.seq_len = int_seq_len
configs.best_us = kind_config['best_us'] configs.best_us = kind_config['best_us']
seq_map[int_seq_len] = configs seq_map[int_seq_len] = configs
if kind_config['kernel_kind'] == 'v1': if kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc_2':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc_2
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config'] best_config = kind_config['best_config']
stage1 = best_config['stage1'] stage1 = best_config['stage1']
stage2 = best_config['stage2'] stage2 = best_config['stage2']
...@@ -158,10 +125,10 @@ def get_config_map(attention_configs): ...@@ -158,10 +125,10 @@ def get_config_map(attention_configs):
stage1 = best_config['stage1'] stage1 = best_config['stage1']
stage2 = best_config['stage2'] stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2 configs.kernel_kind = KERNLE_KINDS.v2
if 'BLOCK_SEQ' in stage1: # if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] # configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else: # else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N'] configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages'] configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps'] configs.num_warps = stage1['num_warps']
...@@ -172,11 +139,12 @@ def get_config_map(attention_configs): ...@@ -172,11 +139,12 @@ def get_config_map(attention_configs):
stage1 = best_config['stage1'] stage1 = best_config['stage1']
stage2 = best_config['stage2'] stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2_tc configs.kernel_kind = KERNLE_KINDS.v2_tc
if 'BLOCK_SEQ' in stage1: # if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] # configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else: # else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N'] configs.BLOCK_N = stage1['BLOCK_N']
configs.BLOCK_DIM = stage1['BLOCK_DIM']
configs.num_stages = stage1['num_stages'] configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps'] configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages'] configs.num_stages_2 = stage2['num_stages']
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -742,9 +743,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -742,9 +743,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO # TODO
for bs in self.attn_configs.keys(): max_seq_len = torch.max(decode_meta.seq_lens_tensor).item()
for mean_seq_len in self.attn_configs[bs].keys(): if os.environ.get('PA_MATCH_USE_MEAN_SEQ') == '1':
best_config = get_config(bs, mean_seq_len, self.attn_configs) match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item())
else:
match_seq_len = max_seq_len
best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))]
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
......
...@@ -36,7 +36,7 @@ import triton.language as tl ...@@ -36,7 +36,7 @@ import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
from ..backends.triton_config import KERNLE_KINDS # from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
...@@ -676,7 +676,7 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -676,7 +676,7 @@ def _decode_v1_kernel_stage1_use_tc(
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_k_id = tl.program_id(2) split_k_id = tl.program_id(2)
reduce_dtype = Att_Out.dtype.element_ty # reduce_dtype = Att_Out.dtype.element_ty
if BLOCK_H < kv_group_num: if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H VALID_BLOCK_H: tl.constexpr = BLOCK_H
...@@ -695,14 +695,14 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -695,14 +695,14 @@ def _decode_v1_kernel_stage1_use_tc(
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load( q = tl.load(
Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
).to(reduce_dtype) ) # .to(reduce_dtype)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = ( off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
) )
qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0) # .to(reduce_dtype)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
split_k_start = kv_len_per_split * split_k_id split_k_start = kv_len_per_split * split_k_id
...@@ -726,7 +726,7 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -726,7 +726,7 @@ def _decode_v1_kernel_stage1_use_tc(
K_Buffer + offs_buf_k, K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
other=0.0, other=0.0,
).to(reduce_dtype) ) # .to(reduce_dtype)
qk = tl.dot(q, k) qk = tl.dot(q, k)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_buf_kpe = ( offs_buf_kpe = (
...@@ -738,7 +738,7 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -738,7 +738,7 @@ def _decode_v1_kernel_stage1_use_tc(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=offs_n[None, :] < split_k_end, mask=offs_n[None, :] < split_k_end,
other=0.0, other=0.0,
).to(reduce_dtype) ) # .to(reduce_dtype)
qk += tl.dot(qpe, kpe) qk += tl.dot(qpe, kpe)
qk *= sm_scale qk *= sm_scale
...@@ -917,18 +917,18 @@ def _decode_v1_stage1_use_tc( ...@@ -917,18 +917,18 @@ def _decode_v1_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_N = best_config.BLOCK_N BLOCK_N = best_config['stage1']['BLOCK_N']
SPLIT_K = num_kv_splits # best_config.SPLIT_K SPLIT_K = num_kv_splits # best_config.SPLIT_K
num_stages = best_config.num_stages num_stages = best_config['stage1']['num_stages']
num_warps = best_config.num_warps num_warps = best_config['stage1']['num_warps']
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: ( grid = lambda META: (
batch, batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K, SPLIT_K,
) )
if best_config.decode_fwd_stage1 is None: _decode_v1_kernel_stage1_use_tc[grid](
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q, q,
k_buffer, k_buffer,
sm_scale, sm_scale,
...@@ -957,23 +957,6 @@ def _decode_v1_stage1_use_tc( ...@@ -957,23 +957,6 @@ def _decode_v1_stage1_use_tc(
Lk=Lk, Lk=Lk,
kpack=2, kpack=2,
) )
else:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
)
# return _decode_v1_kernel_stage1_use_tc.best_config # return _decode_v1_kernel_stage1_use_tc.best_config
...@@ -992,16 +975,16 @@ def _decode_v1_stage2_use_tc( ...@@ -992,16 +975,16 @@ def _decode_v1_stage2_use_tc(
batch, head_num = b_seq_len.shape[0], logits.shape[0] batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2] kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config.BLOCK_N_2 BLOCK_N = best_config['stage2']['BLOCK_N']
num_stages = best_config.num_stages_2 num_stages = best_config['stage2']['num_stages']
num_warps = best_config.num_warps_2 num_warps = best_config['stage2']['num_warps']
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv) BLOCK_DMODEL = triton.next_power_of_2(Lv)
if best_config.decode_fwd_stage2 is None:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid]( _decode_v1_kernel_stage2_use_tc[grid](
logits, logits,
v_buffer, v_buffer,
o, o,
...@@ -1025,22 +1008,6 @@ def _decode_v1_stage2_use_tc( ...@@ -1025,22 +1008,6 @@ def _decode_v1_stage2_use_tc(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
else:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
# return _decode_v1_kernel_stage2_use_tc.best_config # return _decode_v1_kernel_stage2_use_tc.best_config
...@@ -1115,21 +1082,16 @@ def decode_attention_v1( ...@@ -1115,21 +1082,16 @@ def decode_attention_v1(
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
# ], # ],
# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"] # key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
# ) # )
...@@ -1159,6 +1121,7 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1159,6 +1121,7 @@ def _decode_v2_kernel_stage1_use_tc(
BLOCK_DPE: tl.constexpr, BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr, BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_DIM: tl.constexpr,
BLOCK_H: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr, PAGE_SIZE: tl.constexpr,
...@@ -1179,16 +1142,16 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1179,16 +1142,16 @@ def _decode_v2_kernel_stage1_use_tc(
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num) mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL) # offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV) offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk # mask_d = offs_d < Lk
mask_dv = offs_dv < Lv mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# cur_batch_req_idx = tl.load(B_req_idx + cur_batch) # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_req_idx = cur_batch cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
...@@ -1208,26 +1171,25 @@ def _decode_v2_kernel_stage1_use_tc( ...@@ -1208,26 +1171,25 @@ def _decode_v2_kernel_stage1_use_tc(
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
NUM_DIM_SPLIT = tl.cdiv(BLOCK_DMODEL, BLOCK_DIM)
if split_kv_end > split_kv_start: if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N): for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N) offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load( kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE, Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end, mask=offs_n < split_kv_end,
other=0,
) )
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)
+ cur_kv_head * stride_buf_kh for i in range(0, NUM_DIM_SPLIT):
+ offs_d[:, None] offs_d = tl.arange(0, BLOCK_DIM) + i * BLOCK_DIM
) mask_d = offs_d < Lk
k = tl.load( offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None,:]
K_Buffer + offs_buf_k, q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), offs_buf_k = kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None]
other=0.0, k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
) qk += tl.dot(q, k.to(q.dtype))
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_buf_kpe = ( offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs kv_loc[None, :] * stride_buf_kbs
...@@ -1311,9 +1273,10 @@ def _decode_v2_stage1_use_tc( ...@@ -1311,9 +1273,10 @@ def _decode_v2_stage1_use_tc(
logit_cap, logit_cap,
): ):
BLOCK = best_config.BLOCK_N BLOCK = best_config['stage1']['BLOCK_N']
num_stages = best_config.num_stages BLOCK_DIM = best_config['stage1']['BLOCK_DIM']
num_warps = best_config.num_warps num_stages = best_config['stage1']['num_stages']
num_warps = best_config['stage1']['num_warps']
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1341,8 +1304,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1341,8 +1304,7 @@ def _decode_v2_stage1_use_tc(
NUM_KV_SPLITS, NUM_KV_SPLITS,
) )
if best_config.decode_fwd_stage1 is None: _decode_v2_kernel_stage1_use_tc[grid](
best_config.decode_fwd_stage1 =_decode_v2_kernel_stage1_use_tc[grid](
q, q,
k_buffer, k_buffer,
v_buffer, v_buffer,
...@@ -1367,6 +1329,7 @@ def _decode_v2_stage1_use_tc( ...@@ -1367,6 +1329,7 @@ def _decode_v2_stage1_use_tc(
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
BLOCK_DIM=BLOCK_DIM,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS, NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
...@@ -1377,33 +1340,14 @@ def _decode_v2_stage1_use_tc( ...@@ -1377,33 +1340,14 @@ def _decode_v2_stage1_use_tc(
Lv=Lv, Lv=Lv,
kpack=2, kpack=2,
) )
else:
best_config.decode_fwd_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
)
# return _decode_v2_kernel_stage1_use_tc.best_config # return _decode_v2_kernel_stage1_use_tc.best_config
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1), # triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1), # triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1), # triton.Config({}, num_warps=8, num_stages=1),
...@@ -1476,8 +1420,8 @@ def _decode_v2_stage2_use_tc( ...@@ -1476,8 +1420,8 @@ def _decode_v2_stage2_use_tc(
num_kv_splits, num_kv_splits,
best_config, best_config,
): ):
num_stages = best_config.num_stages_2 num_stages = best_config['stage2']['num_stages']
num_warps = best_config.num_warps_2 num_warps = best_config['stage2']['num_warps']
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1486,8 +1430,7 @@ def _decode_v2_stage2_use_tc( ...@@ -1486,8 +1430,7 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS = num_kv_splits NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num, 1) grid = (batch, head_num, 1)
if best_config.decode_fwd_stage2 is None: _decode_v2_kernel_stage2[grid](
best_config.decode_fwd_stage2 = _decode_v2_kernel_stage2[grid](
logits, logits,
o, o,
b_seq_len, b_seq_len,
...@@ -1502,17 +1445,6 @@ def _decode_v2_stage2_use_tc( ...@@ -1502,17 +1445,6 @@ def _decode_v2_stage2_use_tc(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
else:
best_config.decode_fwd_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
)
# return _decode_v2_kernel_stage2.best_config # return _decode_v2_kernel_stage2.best_config
...@@ -1580,7 +1512,7 @@ def decode_attention_fwd( ...@@ -1580,7 +1512,7 @@ def decode_attention_fwd(
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2] kv_group_num = q.shape[1] // v_buffer.shape[-2]
b_start_loc = torch.arange(0, k_buffer.shape[0] * page_size, k_buffer.shape[0] * page_size // q.shape[0], device="cuda").to(torch.int32) b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32)
if kv_group_num == 1: if kv_group_num == 1:
# MHA # MHA
decode_attention_fwd_normal( decode_attention_fwd_normal(
...@@ -1614,8 +1546,8 @@ def decode_attention_fwd( ...@@ -1614,8 +1546,8 @@ def decode_attention_fwd(
logit_cap, logit_cap,
) )
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size), (q.shape[1],req_to_token.shape[0]*req_to_token.shape[1]*page_size),
dtype=torch.float16, dtype=torch.float32,
device="cuda") device="cuda")
decode_attention_v1( decode_attention_v1(
q, q,
...@@ -1632,10 +1564,10 @@ def decode_attention_fwd( ...@@ -1632,10 +1564,10 @@ def decode_attention_fwd(
logit_cap, logit_cap,
)''' )'''
if best_config.kernel_kind == KERNLE_KINDS.v1_2stages_tc: if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size), (q.shape[1],req_to_token.shape[0]*req_to_token.shape[1]*page_size),
dtype=torch.float16, dtype=torch.float32,
device="cuda") device="cuda")
decode_attention_v1( decode_attention_v1(
q, q,
...@@ -1648,11 +1580,11 @@ def decode_attention_fwd( ...@@ -1648,11 +1580,11 @@ def decode_attention_fwd(
attn_logits_v1, attn_logits_v1,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config, best_config=best_config['best_config'],
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
elif best_config.kernel_kind == KERNLE_KINDS.v2_tc: elif best_config['kernel_kind'] == 'v2_tc':
decode_attention_v2( decode_attention_v2(
q, q,
k_buffer, k_buffer,
...@@ -1663,7 +1595,7 @@ def decode_attention_fwd( ...@@ -1663,7 +1595,7 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config, best_config=best_config['best_config'],
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment