// SPDX-License-Identifier: MIT #include #include #include #include #include #include "aiter_hip_common.h" #include "py_itfs_common.h" struct __attribute__((packed)) KernelArgs { void *ptr_O; p2 _p0; void *ptr_Q; p2 _p1; void *ptr_K; p2 _p2; void *ptr_V; p2 _p3; void *ptr_BT; p2 _p4; void *ptr_CL; p2 _p5; void *ptr_KQ; p2 _p6; void *ptr_VQ; p2 _p7; float sclg2e; p3 _p12; unsigned int mblk; p3 _p13; unsigned int kv_nheads; p3 _p14; unsigned int Qs; p3 _p15; unsigned int Bs; p3 _p16; unsigned int KVs; p3 _p17; unsigned int GQA; p3 _p18; }; const float f_log2E = log2f(expf(1)); torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, head_size] torch::Tensor &K, // [num_blocks, num_kv_heads, head_size/x, block_size, x] torch::Tensor &V, // [num_blocks, num_kv_heads, block_size/X, head_size, X] torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor &context_lens, // [num_seqs] int max_num_blocks, std::optional &K_QScale, std::optional &V_QScale, std::optional &out_, std::optional high_precision = 1) { torch::Tensor output = out_.value_or(torch::empty_like(Q)); int batch = context_lens.size(0); // int max_num_blocks = block_tables.size(1); int num_heads = Q.size(1); int head_size = Q.size(2); int num_kv_heads = K.size(1); int block_size = K.size(3); const int gqa_ratio = num_heads / num_kv_heads; TORCH_CHECK(block_size == 16, __func__, " for now only support block_size == 16"); int dim = head_size; int stride_Q = Q.stride(0) * Q.itemsize(); int stride_KV_head = block_size * dim * K.itemsize(); int stride_KV_blk = stride_KV_head * num_kv_heads; float k_log2e = f_log2E; float k_scalar = sqrt(dim); k_scalar = (float)((double)k_log2e / (double)k_scalar); KernelArgs args; size_t arg_size = sizeof(args); args.ptr_O = output.data_ptr(); args.ptr_Q = Q.data_ptr(); args.ptr_K = K.data_ptr(); args.ptr_V = V.data_ptr(); args.ptr_BT = block_tables.data_ptr(); args.ptr_CL = context_lens.data_ptr(); if (K_QScale) { args.ptr_KQ = K_QScale.value().data_ptr(); args.ptr_VQ = V_QScale.value().data_ptr(); } else { args.ptr_KQ = nullptr; args.ptr_VQ = nullptr; } args.sclg2e = k_scalar; args.mblk = max_num_blocks; args.kv_nheads = num_kv_heads; args.Qs = stride_Q; args.Bs = stride_KV_blk; args.KVs = stride_KV_head; args.GQA = gqa_ratio; // std::cout << "sclg2e: " << args.sclg2e << " mblk:" << args.mblk << " kv_nheads:" << args.kv_nheads << " Qs:" << args.Qs << " Bs:" << args.Bs << " KVs:" << args.KVs << std::endl; const at::cuda::OptionalCUDAGuard device_guard(device_of(Q)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AiterAsmKernel *impl_ptr = nullptr; if (K_QScale) { if (Q.dtype() == at::ScalarType::Half) { if (K.dtype() == at::ScalarType::Byte || K.dtype() == at::ScalarType::Char) { static AiterAsmKernel impl_a16w8_f16_i8("pa_a16w8_2tg_g8_i8", "pa_a16w8_f16_2tg_g8_i8.co"); impl_ptr = &impl_a16w8_f16_i8; } else if (K.dtype() == torch_fp8) { if (high_precision.value() == 0) { static AiterAsmKernel impl_a16w8_f16_f8("_ZN5aiter32pa_fp16_pertokenFp8_a16w8_2tg_g8E", "/pa/pa_fp16_pertokenFp8_a16w8_2tg_g8.co"); impl_ptr = &impl_a16w8_f16_f8; } else if (high_precision.value() == 1) { static AiterAsmKernel impl_a16w8_2tg_g8_f8_q_fp16_tail_bf16("_ZN5aiter45pa_bf16_pertokenFp8_a16w8_g8_q_fp16_tail_bf16E", "/pa/pa_bf16_pertokenFp8_a16w8_g8_q_fp16_tail_bf16.co"); impl_ptr = &impl_a16w8_2tg_g8_f8_q_fp16_tail_bf16; } else { TORCH_CHECK(false, __func__, ": high_precision value only support (0, 1) grades on fp16 asm pa for fp8 kv cache !!!"); } } } else if (Q.dtype() == at::ScalarType::BFloat16) { if (K.dtype() == at::ScalarType::Byte || K.dtype() == at::ScalarType::Char) { static AiterAsmKernel impl_a16w8_b16_i8("_ZN5aiter33pa_bf16_pertokenInt8_a16w8_2tg_g8E", "/pa/pa_bf16_pertokenInt8_a16w8_2tg_g8.co"); impl_ptr = &impl_a16w8_b16_i8; } else if (K.dtype() == torch_fp8) { if (high_precision.value() == 0) { static AiterAsmKernel impl_a16w8_b16_f8("_ZN5aiter32pa_bf16_pertokenFp8_a16w8_2tg_g8E", "/pa/pa_bf16_pertokenFp8_a16w8_2tg_g8.co"); impl_ptr = &impl_a16w8_b16_f8; } else if (high_precision.value() == 1) { static AiterAsmKernel impl_a16w8_b16_f8_tail_bf16("_ZN5aiter42pa_bf16_pertokenFp8_a16w8_2tg_g8_tail_bf16E", "/pa/pa_bf16_pertokenFp8_a16w8_2tg_g8_tail_bf16.co"); impl_ptr = &impl_a16w8_b16_f8_tail_bf16; } else if (high_precision.value() == 2) { static AiterAsmKernel impl_a16w8_b16_f8_gemm1_bf16("_ZN5aiter43pa_bf16_pertokenFp8_a16w8_2tg_g8_gemm1_bf16E", "/pa/pa_bf16_pertokenFp8_a16w8_2tg_g8_gemm1_bf16.co"); impl_ptr = &impl_a16w8_b16_f8_gemm1_bf16; } else { TORCH_CHECK(false, __func__, ": high_precision value only support (0, 1, 2) grades on bf16 asm pa for fp8 kv cache !!!"); } } } } else { TORCH_CHECK(Q.is_contiguous(), __func__, ":a16w16 only support Q.is_contiguous() for now"); TORCH_CHECK(num_kv_heads == 1, __func__, ":a16w16 only support num_kv_heads==1, for now"); if (Q.dtype() == at::ScalarType::Half) { static AiterAsmKernel impl_a16w16_f16("pa_kernel_func", "pa_a16w16_f16.co"); impl_ptr = &impl_a16w16_f16; } else if (Q.dtype() == at::ScalarType::BFloat16) { static AiterAsmKernel impl_a16w16_b16("_ZN5aiter22pa_bf16_noquant_a16w16E", "/pa/pa_bf16_noquant_a16w16.co"); impl_ptr = &impl_a16w16_b16; } } TORCH_CHECK(impl_ptr != nullptr, __func__, ": unsupport current Q_type:", Q.scalar_type()); impl_ptr->launch_kernel({&args, &arg_size, num_kv_heads, // gdx batch, // gdy 1, // gdz 256, // bdx: 4 wv64 1, // bdy 1, // bdz stream}); return output; }