Commit 7b681c9e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' into v0.7.3-dev

parents d969c150 bed03c5a
......@@ -161,7 +161,7 @@ def benchmark_config(
nn_moe = False
block_shape=[0, group_size]
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
print(f"shape: {x.shape[0]} | config: {config}")
def prepare(i: int):
input_gating.copy_(gating_output[i])
......@@ -187,6 +187,7 @@ def benchmark_config(
a2_scale=a2_scale,
use_nn_moe=nn_moe,
block_shape=block_shape,
moe_ep_size=1,
)
# JIT compilation & warmup
......@@ -221,8 +222,7 @@ def benchmark_config(
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
print(f"avg: {avg}")
avg = sum(latencies) / (num_iters) * 1000 # us
# graph.reset()
return avg
......@@ -694,7 +694,7 @@ if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="/home/yang/llm-models/vllm-awq-models/DeepSeek-R1-AWQ/")
default="")
parser.add_argument("--tp-size",
"-tp",
"--tensor-parallel-size",
......@@ -711,4 +711,4 @@ if __name__ == "__main__":
parser.add_argument("--trust-remote-code", action="store_true", default=True)
args = parser.parse_args()
main(args)
main(args)
\ No newline at end of file
......@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z;
......@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
k_vec_quant, *k_scale_ptr);
}
}
}
......@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
......@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
......@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
......@@ -770,7 +770,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
......@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
......@@ -805,8 +795,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
......@@ -829,7 +819,8 @@ void paged_attention_v1_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
......@@ -910,7 +901,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
......@@ -928,7 +919,7 @@ void paged_attention_v1_opt(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
......@@ -945,8 +936,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
......@@ -965,7 +956,8 @@ void paged_attention_v2_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
......@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
......@@ -1070,4 +1062,4 @@ void paged_attention_v2_opt(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#undef DIVIDE_ROUND_UP
......@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z;
......@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__)
......@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__)
......@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
......@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
......@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
......@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \
}
void paged_attention_v1_opt(
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
......@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
......@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc(
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
......@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
......@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc(
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
......@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \
}
void paged_attention_v2_opt(
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
......@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
......
......@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
k_vec_quant, *k_scale_ptr);
}
}
......@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
......@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
......@@ -706,8 +706,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
......@@ -728,7 +728,8 @@ void paged_attention_v1_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
......@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
......@@ -880,8 +881,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
......@@ -902,7 +903,8 @@ void paged_attention_v2_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
......@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -1037,4 +1039,4 @@ void paged_attention_v2_with_mask(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#undef DIVIDE_ROUND_UP
......@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
k_vec_quant, *k_scale_ptr);
}
}
}
......@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
......@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
......@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
......@@ -823,8 +823,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
......@@ -849,7 +849,8 @@ void paged_attention_v1_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
......@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
......@@ -978,8 +979,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
......@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
......@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -1114,4 +1116,4 @@ void paged_attention_v2_opt_with_mask(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#undef DIVIDE_ROUND_UP
......@@ -57,7 +57,7 @@ void paged_attention_v1_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -68,7 +68,7 @@ void paged_attention_v2_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
......@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
......
......@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
......@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
......@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
......@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
......@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
......
......@@ -17,6 +17,6 @@ setuptools_scm>=8
torch == 2.4.1
triton == 3.0.0
flash_attn == 2.6.1
lmslim == 0.2.0
lmslim == 0.2.1
numa
......@@ -508,10 +508,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'):
version = 'das.opt1.' + sha[:7]
version = 'das.opt1.cust1.' + sha[:7]
else:
if (major, minor) == ('2', '4'):
version = 'das.opt1'
version = 'das.opt1.cust1'
# dtk version
......
......@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -185,8 +185,8 @@ def paged_attention_v1_opt(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -217,8 +217,8 @@ def paged_attention_v2_opt(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......
......@@ -88,7 +88,8 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
......
......@@ -656,8 +656,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
# 暂时awq不支持cutlass
envs.VLLM_USE_TRITON_AWQ = True
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -861,13 +859,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
# 暂时不支持TN
if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
if self.config.quantization_config["quant_method"] == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [
"self_attn.q_a_proj.qweight",
"self_attn.q_b_proj.qweight",
"self_attn.kv_a_proj_with_mqa.qweight",
"self_attn.kv_b_proj.qweight",
"self_attn.kv_a_proj_with_mqa.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment