Commit 98f67566 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove unused kernels

parent 0a3cede3
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" // #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
......
...@@ -52,47 +52,6 @@ void paged_attention_v2( ...@@ -52,47 +52,6 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(torch::Tensor& output, void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse, std::optional<torch::Tensor> output_lse,
...@@ -191,12 +150,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, ...@@ -191,12 +150,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input_global_scale); torch::Tensor& input_global_scale);
#endif #endif
void persistent_masked_m_silu_mul_quant( // void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H) // const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E) // const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT] // at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT] // at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0); // bool use_ue8m0);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......
...@@ -597,139 +597,139 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] ...@@ -597,139 +597,139 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void persistent_masked_m_silu_mul_quant( // void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H) // const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E) // const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT] // at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT] // at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool cast_scale_ue8m0) { // bool cast_scale_ue8m0) {
#ifndef USE_ROCM // #ifndef USE_ROCM
// This kernel currently only supports H % 128 == 0 and assumes a // // This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128. // // fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128; // static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16); // TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || // TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz); // y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); // TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
bool const is_packed_ue8m0 = // bool const is_packed_ue8m0 =
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0); // (y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0); // TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
using Idx_t = int64_t; // using Idx_t = int64_t;
Idx_t E = input.size(0); // Idx_t E = input.size(0);
Idx_t T = input.size(1); // Idx_t T = input.size(1);
Idx_t H = input.size(2) / 2; // Idx_t H = input.size(2) / 2;
Idx_t stride_i_e = input.stride(0); // Idx_t stride_i_e = input.stride(0);
Idx_t stride_i_t = input.stride(1); // Idx_t stride_i_t = input.stride(1);
Idx_t stride_i_h = input.stride(2); // Idx_t stride_i_h = input.stride(2);
Idx_t stride_yq_e = y_q.stride(0); // Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1); // Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2); // Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_counts_e = tokens_per_expert.stride(0); // Idx_t stride_counts_e = tokens_per_expert.stride(0);
int const NUM_GROUPS = H / GROUP_SIZE; // int const NUM_GROUPS = H / GROUP_SIZE;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// TODO: Get this from cuda_arch ? // // TODO: Get this from cuda_arch ?
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; // static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \ // #define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \ // STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ // static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \ // int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \ // static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \ // GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \ // dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \ // VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ // y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ // vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \ // BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \ // Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \ // <<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ // reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), \ // (fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \ // reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \ // reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ // T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \ // stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \ // STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
}); // });
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \ // #define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \ // STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \ // if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */ \ // /* 8 warp config */ \
static constexpr int NUM_STAGES = 4; \ // static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \ // static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \ // KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \ // STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \ // } else { \
/* 1 warp config */ \ // /* 1 warp config */ \
static constexpr int THREAD_COUNT = 32; \ // static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \ // KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \ // STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
} // }
Idx_t stride_ys_e = y_s.stride(0); // Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1); // Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2); // Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_ys_p = 0; // Idx_t stride_ys_p = 0;
if (!cast_scale_ue8m0) { // if (!cast_scale_ue8m0) {
TORCH_CHECK(!is_packed_ue8m0); // TORCH_CHECK(!is_packed_ue8m0);
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, // LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
false); // false);
return; // return;
} // }
if (!is_packed_ue8m0) { // if (!is_packed_ue8m0) {
// UE8M0 but not packed // // UE8M0 but not packed
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, // LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true); // true);
return; // return;
} // }
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0); // TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kInt32); // TORCH_CHECK(y_s.dtype() == torch::kInt32);
// Int32 packed ue8m0 scales tensor. // // Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number of experts, number of tokens and number of groups // // Let E, T, G be the number of experts, number of tokens and number of groups
// respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales // // respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales
// tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected // // tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows, // // to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,], // // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,] // // [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,] // // [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]] // // [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where TxGy is the ue8m0 scale value of Token x, Group y. // // where TxGy is the ue8m0 scale value of Token x, Group y.
// // //
// In memory (in bytes) the scale values are arranged as, // // In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3, // // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5, // // T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X] // // X, X, T3G4, T3G5, X, X]
// // //
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented // // An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as a uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In // // as a uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// English, ignoring the Experts dimension, the original int32 tensor is // // English, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32 // // simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32
// tensors). The following strides setting reflects this change. Caveat: This // // tensors). The following strides setting reflects this change. Caveat: This
// means that the G dimension is no longer contiguous. i.e. Note that to move // // means that the G dimension is no longer contiguous. i.e. Note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel // // from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case. // // handles this case.
stride_ys_e *= sizeof(int32_t); // stride_ys_e *= sizeof(int32_t);
stride_ys_p = T * sizeof(int32_t); // Packing dimension // stride_ys_p = T * sizeof(int32_t); // Packing dimension
stride_ys_t = sizeof(int32_t); // stride_ys_t = sizeof(int32_t);
stride_ys_g = 1; // stride_ys_g = 1;
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, // LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true); // true);
#endif // #endif
} // }
...@@ -20,12 +20,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -20,12 +20,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
// //
ops.def( // ops.def(
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! " // "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s," // "y_q, Tensor! y_s,"
"bool use_ue8m0) -> ()"); // "bool use_ue8m0) -> ()");
ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA, // ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
&persistent_masked_m_silu_mul_quant); // &persistent_masked_m_silu_mul_quant);
ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
...@@ -63,62 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -63,62 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// Merge attn states // Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case) // can be used to combine partial attention results (in the split-KV case)
...@@ -132,7 +76,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -132,7 +76,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"); " Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
// #ifndef USE_ROCM
ops.def( ops.def(
"convert_vertical_slash_indexes(" "convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, " " Tensor! block_count, Tensor! block_offset, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment