Commit ca4ec0ce authored by lizhigong's avatar lizhigong
Browse files

Merge remote-tracking branch 'origin/v0.7.2-dev' into v0.7.2_zero_overhead

parents 0be169ad ae0ed592
...@@ -5,6 +5,7 @@ import dataclasses ...@@ -5,6 +5,7 @@ import dataclasses
import json import json
import random import random
import time import time
from pathlib import Path
from functools import cache from functools import cache
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import os import os
...@@ -216,15 +217,34 @@ def run_vllm( ...@@ -216,15 +217,34 @@ def run_vllm(
use_beam_search = False use_beam_search = False
print("testing") print("testing")
if not use_beam_search: if not use_beam_search:
start = time.perf_counter() if args.profile:
outputs=llm.generate(prompts, profile_dir = args.profile_result_dir
sampling_params, if not profile_dir:
lora_request=lora_requests, profile_dir = Path(
use_tqdm=True) "."
for output in outputs: ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
generated_text=output.outputs[0].text print(f"Profiling (results will be saved to '{profile_dir}')...")
print(f"test生成的文本: {generated_text}") with torch.profiler.profile(
end = time.perf_counter() activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],record_shapes=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(str(profile_dir))
) as prof:
start = time.perf_counter()
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
print('Prepare time report')
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cuda_time_total", row_limit=-1))
else:
start = time.perf_counter()
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests] prompts = [request.prompt for request in requests]
...@@ -502,6 +522,16 @@ if __name__ == "__main__": ...@@ -502,6 +522,16 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help="Maximum batch size for HF backend.") help="Maximum batch size for HF backend.")
parser.add_argument(
'--profile',
action='store_true',
help='profile the generation process of a single batch')
parser.add_argument(
'--profile-result-dir',
type=str,
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument( parser.add_argument(
'--output-json', '--output-json',
type=str, type=str,
......
...@@ -151,16 +151,16 @@ def benchmark_config( ...@@ -151,16 +151,16 @@ def benchmark_config(
torch.cuda.synchronize() torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
# graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(graph): with torch.cuda.graph(graph):
# for _ in range(10): for _ in range(10):
# run() run()
# torch.cuda.synchronize() torch.cuda.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
# graph.replay() graph.replay()
run() #run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
...@@ -172,28 +172,28 @@ def benchmark_config( ...@@ -172,28 +172,28 @@ def benchmark_config(
torch.cuda.synchronize() torch.cuda.synchronize()
start_event.record() start_event.record()
# graph.replay() graph.replay()
run() #run()
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event)) latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us avg = sum(latencies) / (num_iters * 10) * 1000 # us
# graph.reset() graph.reset()
return avg return avg
def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_m_range = [16, 32, 64, 128, 256] block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256] block_n_range = [32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256] block_k_range = [32, 64, 128, 256]
if not use_fp16: if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8] num_warps_range = [2, 4, 8]
group_m_range = [1, 4, 8, 16, 32] group_m_range = [1, 16, 32, 64]
num_stage_range = [2] num_stage_range = [2, 3, 4, 5]
waves_per_eu_range = [0] #waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] #matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else [] #kpack_range = [1, 2] if use_fp16 else []
param_ranges = { param_ranges = {
"BLOCK_SIZE_M": block_m_range, "BLOCK_SIZE_M": block_m_range,
...@@ -202,7 +202,7 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -202,7 +202,7 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
"GROUP_SIZE_M": group_m_range, "GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range, "num_warps": num_warps_range,
"num_stages": num_stage_range, "num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range, #"waves_per_eu": waves_per_eu_range,
} }
if nn_moe: if nn_moe:
param_ranges["num_ldmatrixes"] = [1] param_ranges["num_ldmatrixes"] = [1]
...@@ -378,6 +378,8 @@ class BenchmarkWorker: ...@@ -378,6 +378,8 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = 1
) -> Tuple[Dict[str, int], float]: ) -> Tuple[Dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(dtype,
...@@ -386,7 +388,7 @@ class BenchmarkWorker: ...@@ -386,7 +388,7 @@ class BenchmarkWorker:
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
dtype_str) dtype_str, use_nn_moe=nn_moe)
if op_config is None: if op_config is None:
config = get_default_config(num_tokens, config = get_default_config(num_tokens,
num_experts, num_experts,
...@@ -394,14 +396,16 @@ class BenchmarkWorker: ...@@ -394,14 +396,16 @@ class BenchmarkWorker:
hidden_size, hidden_size,
topk, topk,
dtype_str, dtype_str,
is_marlin=False) is_marlin=False,
use_nn_moe=nn_moe)
else: else:
config = op_config[min(op_config.keys(), config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))] key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts, kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, topk, dtype, use_fp8_w8a8,
use_int8_w8a16) use_int8_w8a16, nn_moe=nn_moe,
moe_ep_size=moe_ep_size)
return config, kernel_time return config, kernel_time
def tune( def tune(
...@@ -439,7 +443,7 @@ class BenchmarkWorker: ...@@ -439,7 +443,7 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=20, num_iters=10,
nn_moe=nn_moe, nn_moe=nn_moe,
moe_ep_size=moe_ep_size) moe_ep_size=moe_ep_size)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
...@@ -597,7 +601,7 @@ def main(args: argparse.Namespace): ...@@ -597,7 +601,7 @@ def main(args: argparse.Namespace):
else: else:
outputs = _distribute( outputs = _distribute(
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16) topk, dtype, use_fp8_w8a8, use_int8_w8a16, args.nn_moe, moe_ep_size)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
......
This diff is collapsed.
...@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4 ...@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{ {
if constexpr (is_half){ if constexpr (is_half){
asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
else{ else{
asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
} }
...@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0; const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
...@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for(int i=0;i<q_boundary;i++){ for(int i=0;i<q_boundary;i++){
if(thread_idx<16){ if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8); half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
#pragma unroll if constexpr(is_half){
for(int k=0;k<4;k++){ scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
temp.data[0][k]=((float)temp.data[0][k])*scale; #pragma unroll
temp.data[1][k]=((float)temp.data[1][k])*scale; for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
} }
q_vecs[i][thread_idx]=temp; q_vecs[i][thread_idx]=temp;
} }
...@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){ if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0; if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid; const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){ if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
...@@ -316,13 +321,12 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -316,13 +321,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
__syncthreads(); __syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if (q_boundary<=2){ if constexpr(REUSE_KV_TIMES<=2){
constexpr int acc_size = REUSE_KV_TIMES==1?1:2; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
float accs[acc_size][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll #pragma unroll
for(int k=0;k<acc_size;k++) for(int k=0;k<REUSE_KV_TIMES;k++)
{ {
accs[k][i] = 0.f; accs[k][i] = 0.f;
} }
...@@ -356,7 +360,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -356,7 +360,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t out_vec={0,0,0,0}; float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec); builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){ if(rows==k){
for(int resuseid=0;resuseid<acc_size;resuseid++){ for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
accs[resuseid][i]+=out_vec[resuseid]; accs[resuseid][i]+=out_vec[resuseid];
} }
} }
...@@ -366,8 +370,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -366,8 +370,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads(); __syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float; using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps. // Perform reduction across warps.
#pragma unroll for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
for(int reuse_kv_idx=0; reuse_kv_idx<acc_size; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){ if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem); floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll #pragma unroll
...@@ -780,97 +783,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern ...@@ -780,97 +783,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);} max_num_partitions,PARTITION_SIZE);}
static void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions, void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks) int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks);
{
reusekv=1;
num_thread=256;
PARTITION_SIZE=512;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
else if(qheads==52||qheads==26){reusekv=16;}
else reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
return;
}
}
if(qheads==kvheads){
if(max_seq_len<=8192){
if(batchsize*qheads>=512){
max_num_partitions=1;
num_thread=64;
}
if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
}
return;
}
if(max_seq_len<800)max_num_partitions=1;
if(qheads>kvheads*4){
if(max_seq_len<=1000||
max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
max_seq_len<1900&&batchsize>=8&&qheads==28
)
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=8;
else if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
if(max_num_partitions==1){
if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize;
if(bytes<51200)reusekv=1;
else if(bytes<256000)reusekv=4;
else reusekv=8;
return;
}
if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
else reusekv=8;
return;
}
if(blocks<150)return;
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return;
}
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
}
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc_with_mask( void paged_attention_v2_launcher_opt_tc_with_mask(
...@@ -995,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -995,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
break; \ break; \
} }
void paged_attention_v2_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,// [num_seqs, max_seq_len]
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc_with_mask( void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
...@@ -1042,39 +933,11 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -1042,39 +933,11 @@ void paged_attention_v2_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_with_mask(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
}
} }
void paged_attention_v1_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc_with_mask( void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
...@@ -1095,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -1095,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_with_mask(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride); blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
} }
#undef WARP_SIZE #undef WARP_SIZE
......
# SPDX-License-Identifier: Apache-2.0
import os
import json
import pytest
import torch
import triton
from triton_decode_attention import decode_attentionv1_fwd, decode_attentionv2_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [1])
# @pytest.mark.parametrize("L", [100])
# @pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000])
@pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000,8500,9000,9500,10000,10500,11000,11500,12000,12500,13000,13500,14000,14500,15000,15500,16000,16500,17000,17500,18000,18500,19000,19500,20000,20500,21000,21500,22000,22500,23000,23500,24000,24500,25000,25500,26000,26500,27000,27500,28000,28500,29000,29500,30000,30500,31000,31500,32000,32500])
@pytest.mark.parametrize("H_Q", [4, 8, 16])
@pytest.mark.parametrize("H_KV", [1])
@pytest.mark.parametrize("D_QK", [576])
@pytest.mark.parametrize("D_V", [512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 4
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) #这里为向上取整,65,(1027+16-1)//16
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), #shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
b_req_idx = torch.arange(B, device="cuda").to(torch.int32)
# Call the original implementation.
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
configs = {
"v2_tc": {"stage1": {}, "stage2": {}},
"v1_2stages_tc": {"stage1": {}, "stage2": {}},
}
ms = {
"v1_2stages_tc": 10000.0,
"v2_tc": 10000.0,
}
final_best_config = {
"kernel_kind": "",
"best_config": {},
"best_us": 0.0,
}
v2_tc_stage1_best_config, v2_tc_stage2_best_config = decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
quantiles = [0.5, 0.2, 0.8]
v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v2_tc_stage1_best_config.kwargs.items():
configs["v2_tc"]["stage1"][key] = value
configs["v2_tc"]["stage1"]["num_stages"] = v2_tc_stage1_best_config.num_stages
configs["v2_tc"]["stage1"]["num_warps"] = v2_tc_stage1_best_config.num_warps
for key, value in v2_tc_stage2_best_config.kwargs.items():
configs["v2_tc"]["stage2"][key] = value
configs["v2_tc"]["stage2"]["num_stages"] = v2_tc_stage2_best_config.num_stages
configs["v2_tc"]["stage2"]["num_warps"] = v2_tc_stage2_best_config.num_warps
ms["v2_tc"] = v2_tc_ms
print(f"v2_tc best configs is {configs['v2_tc']}")
print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
o2 = torch.zeros_like(o)
v1_tc_stage1_best_config, v1_tc_stage2_best_config = decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o2,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v1_tc_stage1_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage1"][key] = value
configs["v1_2stages_tc"]["stage1"]["num_stages"] = v1_tc_stage1_best_config.num_stages
configs["v1_2stages_tc"]["stage1"]["num_warps"] = v1_tc_stage1_best_config.num_warps
configs["v1_2stages_tc"]["stage1"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
for key, value in v1_tc_stage2_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage2"][key] = value
configs["v1_2stages_tc"]["stage2"]["num_stages"] = v1_tc_stage2_best_config.num_stages
configs["v1_2stages_tc"]["stage2"]["num_warps"] = v1_tc_stage2_best_config.num_warps
configs["v1_2stages_tc"]["stage2"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
ms["v1_2stages_tc"] = v1_tc_ms
min_key, min_ms = min(ms.items(), key=lambda x: x[1])
final_best_config["kernel_kind"] = min_key
final_best_config["best_config"] = configs[min_key]
final_best_config["best_us"] = min_ms * 1000
print(f"v1_2stages_tc best configs is {configs['v1_2stages_tc']}")
print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
print(f"Tuned_decode_attention choose {min_key} kernel, min cost {min_ms} ms, best config of {min_key} kernel is {configs[min_key]}")
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
#**************save config**************#
batch = b_req_idx.shape[0]
mean_seq_len = int((b_seq_len.sum() / max(1, batch)).item())
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_K100AI.json"
elif "BW" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
if os.path.exists(file_name):
with open(file_name, 'r') as file:
config_info = json.load(file)
else:
config_info = {}
# 如果 config_info 中没有当前的 batch,初始化它为一个空字典
# if f"{batch}" not in config_info:
# config_info[f"{batch}"] = {}
# 把新的 mean_seq_len 配置加入到当前 batch 中
# config_info[f"{batch}"][f"{mean_seq_len}"] = final_best_config
config_info[f"{mean_seq_len}"] = final_best_config
# 保存最佳配置
with open(file_name, 'w') as file:
json.dump(config_info, file, indent=1)
#**************save config**************#
This diff is collapsed.
...@@ -488,10 +488,12 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -488,10 +488,12 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1.' + sha[:7] # version = 'das.opt1.cust2.' + sha[:7]
version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.cust1' # version = 'das.opt1.cust2'
version = 'das.opt1'
# dtk version # dtk version
...@@ -688,7 +690,7 @@ package_data = { ...@@ -688,7 +690,7 @@ package_data = {
"model_executor/layers/fused_moe/configs/*.json", "model_executor/layers/fused_moe/configs/*.json",
"model_executor/layers/quantization/utils/configs/*.json", "model_executor/layers/quantization/utils/configs/*.json",
"benchmarks/*.py", "benchmarks/*.py",
"model_executor/layers/quantization/configs/w8a8/*.json", "attention/backends/configs/*.json",
"model_executor/layers/quantization/configs/awq/*.json" "model_executor/layers/quantization/configs/awq/*.json"
] ]
} }
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import pytest import pytest
import torch import torch
import triton
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd, decode_attention_v1, decode_attention_v2
def cdiv(a, b): def cdiv(a, b):
return (a + b - 1) // b return (a + b - 1) // b
...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale = 1.0 / (D_QK**0.5) sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8 num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) # 向上取整:65, (1027+16-1)//16
req_to_page = torch.randint(0, req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE, CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), (B, num_pages_per_batch, 1), #shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device="cuda") device="cuda")
req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) # 维度扩展,从torch.Size([3, 65, 1])扩展至torch.Size([3, 65, 16])
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1) 1, 1, -1)
req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token.view(B, -1)
...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda") b_seq_len = torch.full((B, ), seq_len, device="cuda")
b_start_loc = torch.arange(0, k_buffer.shape[0] * PAGE_SIZE, k_buffer.shape[0] * PAGE_SIZE // q.shape[0], device="cuda").to(torch.int32)
attn_logits_v1 = torch.empty(
(q.shape[1], k_buffer.shape[0]*PAGE_SIZE),
dtype=torch.float16,
device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), (B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device="cuda",
) )
quantiles = [0.5, 0.2, 0.8]
# Call the original implementation. # Call the original implementation.
decode_attention_fwd( decode_attention_fwd(
...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ...@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale, sm_scale,
PAGE_SIZE, PAGE_SIZE,
) )
assert torch.allclose(o, o1) assert torch.allclose(o, o1)
# v0_tc_ms, v0_tc_min_ms, v0_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_fwd(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention ori kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v0_tc_ms)
decode_attention_v1(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
decode_attention_v2(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
# v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
import functools
import json
import torch
import os
from enum import Enum
from typing import Any, Dict, Optional, Tuple
import bisect
from vllm.logger import init_logger
logger = init_logger(__name__)
class KERNLE_KINDS(Enum):
v1_2stages = 0
v1_2stages_tc = 1
v2 = 2
v2_tc = 3
TOTAL_KIND = 4
class BestConfig():
def __init__(self):
self.batch_size = 0
self.seq_len = 0
self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
self.BLOCK_N = 0
self.BLOCK_DIM = 0
# self.BLOCK_SEQ = 0
# self.SPLIT_K = 0
self.num_stages = 0
self.num_warps = 0
self.NUM_KV_SPLITS = 0
self.BLOCK_N_2 = 0
self.num_stages_2 = 0
self.num_warps_2 = 0
self.best_us = 0
self.decode_fwd_stage1 = None
self.decode_fwd_stage2 = None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
# logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
def get_config_map(attention_configs):
ret_map = {}
for bs in attention_configs.keys():
int_bs = int(bs)
seq_map = {}
seq_configs = attention_configs[bs]
ret_map[int_bs] = seq_map
for seq_len in seq_configs.keys():
int_seq_len = int(seq_len)
kind_config = seq_configs[seq_len]
configs = BestConfig()
# configs.batch_size = int_bs
# configs.seq_len = int_seq_len
configs.best_us = kind_config['best_us']
seq_map[int_seq_len] = configs
if kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2
# if 'BLOCK_SEQ' in stage1:
# configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
# else:
# configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2_tc
# if 'BLOCK_SEQ' in stage1:
# configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
# else:
# configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.BLOCK_DIM = stage1['BLOCK_DIM']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
return ret_map
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype)
return get_config_map(attention_configs)
def get_closest_key(dic_keys, target_key):
keys = list(dic_keys)
idx = bisect.bisect_left(keys, target_key)
if idx == 0:
return keys[0]
if idx == len(keys):
return keys[-1]
left_key = keys[idx - 1]
right_key = keys[idx]
if target_key - left_key <= right_key - target_key:
return left_key
else:
return right_key
def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
closest_bs_key = get_closest_key(config.keys(), bs_key)
closest_mean_kv_seqlen_key = get_closest_key(config[closest_bs_key].keys(), mean_kv_seqlen_key)
return config[closest_bs_key][closest_mean_kv_seqlen_key]
def get_config(bs_key, mean_kv_seqlen_key, config):
if bs_key in config and mean_kv_seqlen_key in config[bs_key]:
return config[bs_key][mean_kv_seqlen_key]
else:
raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -7,6 +8,7 @@ from itertools import accumulate ...@@ -7,6 +8,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json
try: try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
...@@ -16,6 +18,8 @@ except ImportError: ...@@ -16,6 +18,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
...@@ -32,7 +36,7 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad ...@@ -32,7 +36,7 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
class TritonMLABackend(AttentionBackend): class TritonMLABackend(AttentionBackend):
...@@ -682,7 +686,9 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -682,7 +686,9 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
"TritonMLAImpl") "TritonMLAImpl")
self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill( def _forward_prefill(
self, self,
q: torch.Tensor, q: torch.Tensor,
...@@ -735,12 +741,21 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -735,12 +741,21 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO
max_seq_len = torch.max(decode_meta.seq_lens_tensor).item()
if os.environ.get('PA_MATCH_USE_MEAN_SEQ') == '1':
match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item())
else:
match_seq_len = max_seq_len
best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))]
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits, decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, attn_metadata.num_kv_splits, self.scale, best_config,
PAGE_SIZE) PAGE_SIZE)
return self._v_up_proj_and_o_proj(o) return self._v_up_proj_and_o_proj(o)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment