Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98f67566
Commit
98f67566
authored
Dec 13, 2025
by
zhuwenwen
Browse files
remove unused kernels
parent
0a3cede3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
247 deletions
+149
-247
csrc/moe/moe_fused_gate.cu
csrc/moe/moe_fused_gate.cu
+1
-1
csrc/ops.h
csrc/ops.h
+6
-47
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+136
-136
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-63
No files found.
csrc/moe/moe_fused_gate.cu
View file @
98f67566
...
...
@@ -12,7 +12,7 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
//
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
...
...
csrc/ops.h
View file @
98f67566
...
...
@@ -52,47 +52,6 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
...
...
@@ -191,12 +150,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch
::
Tensor
&
input_global_scale
);
#endif
void
persistent_masked_m_silu_mul_quant
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
bool
use_ue8m0
);
//
void persistent_masked_m_silu_mul_quant(
//
const at::Tensor& input, // (E, T, 2*H)
//
const at::Tensor& counts, // (E)
//
at::Tensor& y_q, // (E, T, H) [OUT]
//
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
//
bool use_ue8m0);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/quantization/activation_kernels.cu
View file @
98f67566
...
...
@@ -597,139 +597,139 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
void
persistent_masked_m_silu_mul_quant
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
tokens_per_expert
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
bool
cast_scale_ue8m0
)
{
#ifndef USE_ROCM
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static
constexpr
int
GROUP_SIZE
=
128
;
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
(
GROUP_SIZE
*
2
)
==
0
);
bool
const
is_packed_ue8m0
=
(
y_s
.
dtype
()
==
torch
::
kInt32
&&
cast_scale_ue8m0
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kFloat32
||
is_packed_ue8m0
);
using
Idx_t
=
int64_t
;
Idx_t
E
=
input
.
size
(
0
);
Idx_t
T
=
input
.
size
(
1
);
Idx_t
H
=
input
.
size
(
2
)
/
2
;
Idx_t
stride_i_e
=
input
.
stride
(
0
);
Idx_t
stride_i_t
=
input
.
stride
(
1
);
Idx_t
stride_i_h
=
input
.
stride
(
2
);
Idx_t
stride_yq_e
=
y_q
.
stride
(
0
);
Idx_t
stride_yq_t
=
y_q
.
stride
(
1
);
Idx_t
stride_yq_h
=
y_q
.
stride
(
2
);
Idx_t
stride_counts_e
=
tokens_per_expert
.
stride
(
0
);
int
const
NUM_GROUPS
=
H
/
GROUP_SIZE
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// TODO: Get this from cuda_arch ?
static
constexpr
int
SILU_V2_BLOCK_COUNT
=
132
*
32
;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
});
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */
\
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */
\
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
Idx_t
stride_ys_e
=
y_s
.
stride
(
0
);
Idx_t
stride_ys_t
=
y_s
.
stride
(
1
);
Idx_t
stride_ys_g
=
y_s
.
stride
(
2
);
Idx_t
stride_ys_p
=
0
;
if
(
!
cast_scale_ue8m0
)
{
TORCH_CHECK
(
!
is_packed_ue8m0
);
LAUNCH_ON_H
(
float
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
false
);
return
;
}
if
(
!
is_packed_ue8m0
)
{
// UE8M0 but not packed
LAUNCH_ON_H
(
float
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
true
);
return
;
}
TORCH_CHECK
(
cast_scale_ue8m0
&&
is_packed_ue8m0
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kInt32
);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number to experts, number of tokens and number of groups
// respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
// tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where, TxGy is the scale ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X]
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// english, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
// tensor). The following strides setting reflects this change. Caveat: This
// means that the G dimension is no longer contiguous. i.e. Note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case.
stride_ys_e
*=
sizeof
(
int32_t
);
stride_ys_p
=
T
*
sizeof
(
int32_t
);
// Packing dimension
stride_ys_t
=
sizeof
(
int32_t
);
stride_ys_g
=
1
;
LAUNCH_ON_H
(
uint8_t
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
true
);
#endif
}
//
void persistent_masked_m_silu_mul_quant(
//
const at::Tensor& input, // (E, T, 2*H)
//
const at::Tensor& tokens_per_expert, // (E)
//
at::Tensor& y_q, // (E, T, H) [OUT]
//
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
//
bool cast_scale_ue8m0) {
//
#ifndef USE_ROCM
//
// This kernel currently only supports H % 128 == 0 and assumes a
//
// fixed GROUP_SIZE of 128.
//
static constexpr int GROUP_SIZE = 128;
//
TORCH_CHECK(input.dtype() == torch::kBFloat16);
//
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
//
y_q.dtype() == torch::kFloat8_e4m3fnuz);
//
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
//
bool const is_packed_ue8m0 =
//
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
//
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
//
using Idx_t = int64_t;
//
Idx_t E = input.size(0);
//
Idx_t T = input.size(1);
//
Idx_t H = input.size(2) / 2;
//
Idx_t stride_i_e = input.stride(0);
//
Idx_t stride_i_t = input.stride(1);
//
Idx_t stride_i_h = input.stride(2);
//
Idx_t stride_yq_e = y_q.stride(0);
//
Idx_t stride_yq_t = y_q.stride(1);
//
Idx_t stride_yq_h = y_q.stride(2);
//
Idx_t stride_counts_e = tokens_per_expert.stride(0);
//
int const NUM_GROUPS = H / GROUP_SIZE;
//
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
//
// TODO: Get this from cuda_arch ?
//
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
//
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
//
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
//
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
//
int sms = SILU_V2_BLOCK_COUNT; \
//
static constexpr int max_shared_mem_bytes = \
//
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
//
dim3 grid(sms), block(THREAD_COUNT); \
//
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
//
VLLM_DISPATCH_FP8_TYPES( \
//
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
//
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
//
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
//
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
//
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
//
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
//
(fp8_t*)y_q.data_ptr(), \
//
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
//
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
//
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
//
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
//
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
//
});
//
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
//
STRIDE_YS_P, CEIL_UE8M0) \
//
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
//
/* 8 warp config */ \
//
static constexpr int NUM_STAGES = 4; \
//
static constexpr int THREAD_COUNT = 256; \
//
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
//
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
//
} else { \
//
/* 1 warp config */ \
//
static constexpr int THREAD_COUNT = 32; \
//
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
//
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
//
}
//
Idx_t stride_ys_e = y_s.stride(0);
//
Idx_t stride_ys_t = y_s.stride(1);
//
Idx_t stride_ys_g = y_s.stride(2);
//
Idx_t stride_ys_p = 0;
//
if (!cast_scale_ue8m0) {
//
TORCH_CHECK(!is_packed_ue8m0);
//
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
//
false);
//
return;
//
}
//
if (!is_packed_ue8m0) {
//
// UE8M0 but not packed
//
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
//
true);
//
return;
//
}
//
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
//
TORCH_CHECK(y_s.dtype() == torch::kInt32);
//
// Int32 packed ue8m0 scales tensor.
//
// Let E, T, G be the number to experts, number of tokens and number of groups
//
// respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
//
// tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
//
// to be arranged as follows,
//
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
//
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
//
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
//
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
//
// where, TxGy is the scale ue8m0 scale value of Token x, Group y.
//
//
//
// In memory (in bytes) the scale values are arranged as,
//
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
//
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
//
// X, X, T3G4, T3G5, X, X]
//
//
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
//
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
//
// english, ignoring the Experts dimension, the original int32 tensor is
//
// simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
//
// tensor). The following strides setting reflects this change. Caveat: This
//
// means that the G dimension is no longer contiguous. i.e. Note that to move
//
// from G3 to G4, we need to jump along the packing dimension. The kernel
//
// handles this case.
//
stride_ys_e *= sizeof(int32_t);
//
stride_ys_p = T * sizeof(int32_t); // Packing dimension
//
stride_ys_t = sizeof(int32_t);
//
stride_ys_g = 1;
//
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
//
true);
//
#endif
//
}
csrc/torch_bindings.cpp
View file @
98f67566
...
...
@@ -20,12 +20,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
ops
.
def
(
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s,"
"bool use_ue8m0) -> ()"
);
ops
.
impl
(
"persistent_masked_m_silu_mul_quant"
,
torch
::
kCUDA
,
&
persistent_masked_m_silu_mul_quant
);
//
ops.def(
//
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
//
"y_q, Tensor! y_s,"
//
"bool use_ue8m0) -> ()");
//
ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
//
&persistent_masked_m_silu_mul_quant);
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
...
...
@@ -63,62 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt_tc
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc
);
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
...
...
@@ -132,7 +76,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
// #ifndef USE_ROCM
ops
.
def
(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment