Commit 98f67566 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove unused kernels

parent 0a3cede3
......@@ -12,7 +12,7 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
// #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
......
......@@ -52,47 +52,6 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
......@@ -191,12 +150,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0);
// void persistent_masked_m_silu_mul_quant(
// const at::Tensor& input, // (E, T, 2*H)
// const at::Tensor& counts, // (E)
// at::Tensor& y_q, // (E, T, H) [OUT]
// at::Tensor& y_s, // (E, T, H//group_size) [OUT]
// bool use_ue8m0);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......
......@@ -597,139 +597,139 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool cast_scale_ue8m0) {
#ifndef USE_ROCM
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
bool const is_packed_ue8m0 =
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
using Idx_t = int64_t;
Idx_t E = input.size(0);
Idx_t T = input.size(1);
Idx_t H = input.size(2) / 2;
Idx_t stride_i_e = input.stride(0);
Idx_t stride_i_t = input.stride(1);
Idx_t stride_i_h = input.stride(2);
Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_counts_e = tokens_per_expert.stride(0);
int const NUM_GROUPS = H / GROUP_SIZE;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// TODO: Get this from cuda_arch ?
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
});
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */ \
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */ \
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_ys_p = 0;
if (!cast_scale_ue8m0) {
TORCH_CHECK(!is_packed_ue8m0);
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
false);
return;
}
if (!is_packed_ue8m0) {
// UE8M0 but not packed
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
return;
}
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kInt32);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number to experts, number of tokens and number of groups
// respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
// tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where, TxGy is the scale ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X]
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// english, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
// tensor). The following strides setting reflects this change. Caveat: This
// means that the G dimension is no longer contiguous. i.e. Note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case.
stride_ys_e *= sizeof(int32_t);
stride_ys_p = T * sizeof(int32_t); // Packing dimension
stride_ys_t = sizeof(int32_t);
stride_ys_g = 1;
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
#endif
}
// void persistent_masked_m_silu_mul_quant(
// const at::Tensor& input, // (E, T, 2*H)
// const at::Tensor& tokens_per_expert, // (E)
// at::Tensor& y_q, // (E, T, H) [OUT]
// at::Tensor& y_s, // (E, T, H//group_size) [OUT]
// bool cast_scale_ue8m0) {
// #ifndef USE_ROCM
// // This kernel currently only supports H % 128 == 0 and assumes a
// // fixed GROUP_SIZE of 128.
// static constexpr int GROUP_SIZE = 128;
// TORCH_CHECK(input.dtype() == torch::kBFloat16);
// TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
// y_q.dtype() == torch::kFloat8_e4m3fnuz);
// TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
// bool const is_packed_ue8m0 =
// (y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
// TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
// using Idx_t = int64_t;
// Idx_t E = input.size(0);
// Idx_t T = input.size(1);
// Idx_t H = input.size(2) / 2;
// Idx_t stride_i_e = input.stride(0);
// Idx_t stride_i_t = input.stride(1);
// Idx_t stride_i_h = input.stride(2);
// Idx_t stride_yq_e = y_q.stride(0);
// Idx_t stride_yq_t = y_q.stride(1);
// Idx_t stride_yq_h = y_q.stride(2);
// Idx_t stride_counts_e = tokens_per_expert.stride(0);
// int const NUM_GROUPS = H / GROUP_SIZE;
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// // TODO: Get this from cuda_arch ?
// static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
// #define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
// STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
// static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
// int sms = SILU_V2_BLOCK_COUNT; \
// static constexpr int max_shared_mem_bytes = \
// GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
// dim3 grid(sms), block(THREAD_COUNT); \
// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
// VLLM_DISPATCH_FP8_TYPES( \
// y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
// vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
// BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
// Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
// <<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
// reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
// (fp8_t*)y_q.data_ptr(), \
// reinterpret_cast<scale_t*>(y_s.data_ptr()), \
// reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
// T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
// stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
// STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
// });
// #define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
// STRIDE_YS_P, CEIL_UE8M0) \
// if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
// /* 8 warp config */ \
// static constexpr int NUM_STAGES = 4; \
// static constexpr int THREAD_COUNT = 256; \
// KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
// STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
// } else { \
// /* 1 warp config */ \
// static constexpr int THREAD_COUNT = 32; \
// KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
// STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
// }
// Idx_t stride_ys_e = y_s.stride(0);
// Idx_t stride_ys_t = y_s.stride(1);
// Idx_t stride_ys_g = y_s.stride(2);
// Idx_t stride_ys_p = 0;
// if (!cast_scale_ue8m0) {
// TORCH_CHECK(!is_packed_ue8m0);
// LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
// false);
// return;
// }
// if (!is_packed_ue8m0) {
// // UE8M0 but not packed
// LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
// true);
// return;
// }
// TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
// TORCH_CHECK(y_s.dtype() == torch::kInt32);
// // Int32 packed ue8m0 scales tensor.
// // Let E, T, G be the number to experts, number of tokens and number of groups
// // respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
// // tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// // to be arranged as follows,
// // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// // [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// // [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// // [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// // where, TxGy is the scale ue8m0 scale value of Token x, Group y.
// //
// // In memory (in bytes) the scale values are arranged as,
// // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
// // T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// // X, X, T3G4, T3G5, X, X]
// //
// // An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// // as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// // english, ignoring the Experts dimension, the original int32 tensor is
// // simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
// // tensor). The following strides setting reflects this change. Caveat: This
// // means that the G dimension is no longer contiguous. i.e. Note that to move
// // from G3 to G4, we need to jump along the packing dimension. The kernel
// // handles this case.
// stride_ys_e *= sizeof(int32_t);
// stride_ys_p = T * sizeof(int32_t); // Packing dimension
// stride_ys_t = sizeof(int32_t);
// stride_ys_g = 1;
// LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
// true);
// #endif
// }
......@@ -20,12 +20,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
ops.def(
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s,"
"bool use_ue8m0) -> ()");
ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
&persistent_masked_m_silu_mul_quant);
// ops.def(
// "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
// "y_q, Tensor! y_s,"
// "bool use_ue8m0) -> ()");
// ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
// &persistent_masked_m_silu_mul_quant);
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
......@@ -63,62 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
......@@ -132,7 +76,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
// #ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment