Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
978aed53
Unverified
Commit
978aed53
authored
Jul 16, 2024
by
Michael Goin
Committed by
GitHub
Jul 16, 2024
Browse files
[Kernel][Attention] Separate `Attention.kv_scale` into `k_scale` and `v_scale` (#6081)
parent
160e1d8c
Changes
33
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
157 additions
and
99 deletions
+157
-99
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+5
-3
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+27
-26
csrc/cache.h
csrc/cache.h
+2
-2
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+11
-10
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+6
-6
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+3
-2
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+5
-5
csrc/ops.h
csrc/ops.h
+4
-4
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+5
-5
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+5
-3
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+5
-3
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+2
-2
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+35
-5
vllm/_custom_ops.py
vllm/_custom_ops.py
+11
-7
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+6
-3
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+2
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-3
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+4
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+4
-2
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+9
-5
No files found.
benchmarks/kernels/benchmark_paged_attention.py
View file @
978aed53
...
@@ -100,7 +100,7 @@ def main(
...
@@ -100,7 +100,7 @@ def main(
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
for
_
in
range
(
num_iters
):
for
_
in
range
(
num_iters
):
if
version
==
"v1"
:
if
version
==
"v1"
:
...
@@ -117,7 +117,8 @@ def main(
...
@@ -117,7 +117,8 @@ def main(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
elif
version
==
"v2"
:
elif
version
==
"v2"
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
...
@@ -136,7 +137,8 @@ def main(
...
@@ -136,7 +137,8 @@ def main(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
else
:
else
:
raise
ValueError
(
f
"Invalid version:
{
version
}
"
)
raise
ValueError
(
f
"Invalid version:
{
version
}
"
)
...
...
csrc/attention/attention_kernels.cu
View file @
978aed53
...
@@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
...
@@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k
v
_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_
vert_stride
,
const
int
blocksparse_
block_siz
e
,
const
int
blocksparse_
local_blocks
,
const
int
blocksparse_
vert_strid
e
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
...
@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
...
@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
k
v
_scale
);
k_vec_quant
,
k_scale
);
}
}
}
}
...
@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
...
@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
k
v_scale
);
v_scale
);
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// NOTE(woosuk): When v_vec contains the tokens that are out of the
...
@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
...
@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k
v
_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_
vert_stride
,
const
int
blocksparse_
block_siz
e
,
const
int
blocksparse_
local_blocks
,
const
int
blocksparse_
vert_strid
e
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k
_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
...
@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
...
@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
k
v
_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_
vert_stride
,
const
int
blocksparse_
block_siz
e
,
const
int
blocksparse_
local_blocks
,
const
int
blocksparse_
vert_strid
e
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
kv_block_stride
,
kv_head_stride
,
k
_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
...
@@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks,
\
k
_scale,
v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step);
...
@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
...
@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k
v
_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
...
@@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,
\
seq_lens, max_seq_len, alibi_slopes, k
_scale,
v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step);
...
@@ -815,8 +815,8 @@ void paged_attention_v1(
...
@@ -815,8 +815,8 @@ void paged_attention_v1(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
...
@@ -833,7 +833,7 @@ void paged_attention_v1(
...
@@ -833,7 +833,7 @@ void paged_attention_v1(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
\
kv_block_stride, kv_head_stride, k
_scale,
v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
...
@@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
...
@@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k
v
_scale
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
k_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
...
@@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
switch (is_block_sparse) { \
...
@@ -980,8 +981,8 @@ void paged_attention_v2(
...
@@ -980,8 +981,8 @@ void paged_attention_v2(
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
torch
::
Tensor
&
seq_lens
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_seq_len
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
...
...
csrc/cache.h
View file @
978aed53
...
@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
...
@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
const
double
k_scale
,
const
double
k
v_scale
);
const
double
v_scale
);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
key_cache
,
...
...
csrc/cache_kernels.cu
View file @
978aed53
...
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
...
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size]
// block_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
int
x
,
const
int
head_size
,
const
int
block_size
,
const
int
x
,
const
float
k_scale
,
const
float
k
v_scale
)
{
const
float
v_scale
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
if
(
slot_idx
<
0
)
{
if
(
slot_idx
<
0
)
{
...
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
...
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache
[
tgt_value_idx
]
=
tgt_value
;
value_cache
[
tgt_value_idx
]
=
tgt_value
;
}
else
{
}
else
{
key_cache
[
tgt_key_idx
]
=
key_cache
[
tgt_key_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
k
v
_scale
);
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
k_scale
);
value_cache
[
tgt_value_idx
]
=
value_cache
[
tgt_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_value
,
k
v_scale
);
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_value
,
v_scale
);
}
}
}
}
}
}
...
@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
...
@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale);
num_heads, head_size, block_size, x, k
_scale,
v_scale);
void
reshape_and_cache
(
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
...
@@ -258,7 +258,8 @@ void reshape_and_cache(
...
@@ -258,7 +258,8 @@ void reshape_and_cache(
torch
::
Tensor
&
torch
::
Tensor
&
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
value_cache
,
// [num_blocks, num_heads, head_size, block_size]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
kv_cache_dtype
,
const
double
kv_scale
)
{
const
std
::
string
&
kv_cache_dtype
,
const
double
k_scale
,
const
double
v_scale
)
{
int
num_tokens
=
key
.
size
(
0
);
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
int
head_size
=
key
.
size
(
2
);
...
@@ -318,13 +319,13 @@ namespace vllm {
...
@@ -318,13 +319,13 @@ namespace vllm {
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
convert_fp8_kernel
(
const
Tin
*
__restrict__
src_cache
,
__global__
void
convert_fp8_kernel
(
const
Tin
*
__restrict__
src_cache
,
Tout
*
__restrict__
dst_cache
,
Tout
*
__restrict__
dst_cache
,
const
float
kv_
scale
,
const
float
scale
,
const
int64_t
block_stride
)
{
const
int64_t
block_stride
)
{
const
int64_t
block_idx
=
blockIdx
.
x
;
const
int64_t
block_idx
=
blockIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
;
i
<
block_stride
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
block_stride
;
i
+=
blockDim
.
x
)
{
int64_t
idx
=
block_idx
*
block_stride
+
i
;
int64_t
idx
=
block_idx
*
block_stride
+
i
;
dst_cache
[
idx
]
=
dst_cache
[
idx
]
=
fp8
::
scaled_convert
<
Tout
,
Tin
,
kv_dt
>
(
src_cache
[
idx
],
kv_
scale
);
fp8
::
scaled_convert
<
Tout
,
Tin
,
kv_dt
>
(
src_cache
[
idx
],
scale
);
}
}
}
}
...
@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
...
@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()),
kv_
scale, block_stride);
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
// Only for testing.
// Only for testing.
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
const
double
kv_
scale
,
const
std
::
string
&
kv_cache_dtype
)
{
const
double
scale
,
const
std
::
string
&
kv_cache_dtype
)
{
torch
::
Device
src_device
=
src_cache
.
device
();
torch
::
Device
src_device
=
src_cache
.
device
();
torch
::
Device
dst_device
=
dst_cache
.
device
();
torch
::
Device
dst_device
=
dst_cache
.
device
();
TORCH_CHECK
(
src_device
.
is_cuda
(),
"src must be on a GPU"
)
TORCH_CHECK
(
src_device
.
is_cuda
(),
"src must be on a GPU"
)
...
...
csrc/cpu/attention.cpp
View file @
978aed53
...
@@ -423,11 +423,11 @@ void paged_attention_v1(
...
@@ -423,11 +423,11 @@ void paged_attention_v1(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
TORCH_CHECK
(
k
_scale
==
1.0
f
&&
v_scale
==
1.0
f
);
TORCH_CHECK
(
blocksparse_vert_stride
<=
1
,
TORCH_CHECK
(
blocksparse_vert_stride
<=
1
,
"CPU backend does not support blocksparse attention yet."
);
"CPU backend does not support blocksparse attention yet."
);
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v1_impl"
,
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v1_impl"
,
...
@@ -742,11 +742,11 @@ void paged_attention_v2(
...
@@ -742,11 +742,11 @@ void paged_attention_v2(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
)
{
const
int64_t
blocksparse_head_sliding_step
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
TORCH_CHECK
(
k
_scale
==
1.0
f
&&
v_scale
==
1.0
f
);
TORCH_CHECK
(
blocksparse_vert_stride
<=
1
,
TORCH_CHECK
(
blocksparse_vert_stride
<=
1
,
"CPU backend does not support blocksparse attention yet."
);
"CPU backend does not support blocksparse attention yet."
);
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v2_impl"
,
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v2_impl"
,
...
...
csrc/cpu/cache.cpp
View file @
978aed53
...
@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
...
@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
,
double
kv_scale
)
{
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
TORCH_CHECK
(
kv_scale
==
1.0
f
);
double
v_scale
)
{
TORCH_CHECK
(
k_scale
==
1.0
f
&&
v_scale
==
1.0
f
);
int
num_tokens
=
key
.
size
(
0
);
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
num_heads
=
key
.
size
(
1
);
...
...
csrc/cpu/torch_bindings.cpp
View file @
978aed53
...
@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k
v
_scale,
int tp_rank
,"
" str kv_cache_dtype, float k_scale,
float v_scale
,"
" int blocksparse_local_blocks,"
" int
tp_rank, int
blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCPU
,
&
paged_attention_v1
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCPU
,
&
paged_attention_v1
);
...
@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k
v
_scale,
int tp_rank
,"
" str kv_cache_dtype, float k_scale,
float v_scale
,"
" int blocksparse_local_blocks,"
" int
tp_rank, int
blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCPU
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCPU
,
&
paged_attention_v2
);
...
@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
...
@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" str kv_cache_dtype,"
" float kv_scale) -> ()"
);
" float k
_scale, float
v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCPU
,
&
reshape_and_cache
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCPU
,
&
reshape_and_cache
);
}
}
...
...
csrc/ops.h
View file @
978aed53
...
@@ -8,8 +8,8 @@ void paged_attention_v1(
...
@@ -8,8 +8,8 @@ void paged_attention_v1(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
@@ -19,8 +19,8 @@ void paged_attention_v2(
...
@@ -19,8 +19,8 @@ void paged_attention_v2(
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k
v
_scale
,
const
int64_t
tp_rank
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
...
...
csrc/torch_bindings.cpp
View file @
978aed53
...
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k
v
_scale,
int tp_rank
,"
" str kv_cache_dtype, float k_scale,
float v_scale
,"
" int blocksparse_local_blocks,"
" int
tp_rank, int
blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCUDA
,
&
paged_attention_v1
);
...
@@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k
v
_scale,
int tp_rank
,"
" str kv_cache_dtype, float k_scale,
float v_scale
,"
" int blocksparse_local_blocks,"
" int
tp_rank, int
blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
...
@@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
...
@@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" str kv_cache_dtype,"
" float kv_scale) -> ()"
);
" float k
_scale, float
v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCUDA
,
&
reshape_and_cache
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCUDA
,
&
reshape_and_cache
);
// Reshape the key and value tensors and cache them.
// Reshape the key and value tensors and cache them.
...
...
tests/kernels/test_attention.py
View file @
978aed53
...
@@ -175,7 +175,7 @@ def test_paged_attention(
...
@@ -175,7 +175,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the paged attention kernel.
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
...
@@ -193,7 +193,8 @@ def test_paged_attention(
...
@@ -193,7 +193,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
elif
version
==
"v2"
:
elif
version
==
"v2"
:
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
...
@@ -224,7 +225,8 @@ def test_paged_attention(
...
@@ -224,7 +225,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
else
:
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
...
tests/kernels/test_blocksparse_attention.py
View file @
978aed53
...
@@ -212,7 +212,7 @@ def test_paged_attention(
...
@@ -212,7 +212,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
tp_rank
=
0
tp_rank
=
0
# Call the paged attention kernel.
# Call the paged attention kernel.
...
@@ -231,7 +231,8 @@ def test_paged_attention(
...
@@ -231,7 +231,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
@@ -267,7 +268,8 @@ def test_paged_attention(
...
@@ -267,7 +268,8 @@ def test_paged_attention(
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
tp_rank
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
...
...
tests/kernels/test_cache.py
View file @
978aed53
...
@@ -155,11 +155,11 @@ def test_reshape_and_cache(
...
@@ -155,11 +155,11 @@ def test_reshape_and_cache(
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
# Using default kv_scale
kv_scale
=
1.0
k
_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
kv_cache_dtype
,
k
_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
...
...
tests/quantization/test_fp8.py
View file @
978aed53
...
@@ -7,19 +7,49 @@ import torch
...
@@ -7,19 +7,49 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8KVCacheMethod
,
Fp8LinearMethod
)
MODELS
=
[
MODELS
=
[
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8
-KV
"
,
"nm-testing/Phi-3-mini-128k-instruct-FP8"
,
"nm-testing/Phi-3-mini-128k-instruct-FP8"
,
]
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
MODELS
)
def
test_model_load_and_run
(
vllm_runner
,
model
:
str
):
def
test_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model_id
)
as
llm
:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs
=
llm
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)
print
(
outputs
[
0
][
1
])
KV_CACHE_MODELS
=
[
# Deprecated AutoFP8 format using .kv_scale
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"
,
# AutoFP8 format using separate .k_scale and .v_scale
"nm-testing/Qwen2-1.5B-Instruct-FP8-K-V"
,
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
KV_CACHE_MODELS
)
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
# these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
assert
0.0
<
attn
.
_v_scale
<
1.0
# note: this does not test accuracy, just that we can run through
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
# see lm-eval tests for accuracy
outputs
=
llm
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
outputs
=
llm
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
...
...
vllm/_custom_ops.py
View file @
978aed53
...
@@ -84,7 +84,8 @@ def paged_attention_v1(
...
@@ -84,7 +84,8 @@ def paged_attention_v1(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -94,8 +95,9 @@ def paged_attention_v1(
...
@@ -94,8 +95,9 @@ def paged_attention_v1(
torch
.
ops
.
_C
.
paged_attention_v1
(
torch
.
ops
.
_C
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2
(
def
paged_attention_v2
(
...
@@ -114,7 +116,8 @@ def paged_attention_v2(
...
@@ -114,7 +116,8 @@ def paged_attention_v2(
max_seq_len
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -124,7 +127,7 @@ def paged_attention_v2(
...
@@ -124,7 +127,7 @@ def paged_attention_v2(
torch
.
ops
.
_C
.
paged_attention_v2
(
torch
.
ops
.
_C
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
k
_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
...
@@ -374,11 +377,12 @@ def reshape_and_cache(
...
@@ -374,11 +377,12 @@ def reshape_and_cache(
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
kv_cache_dtype
,
k
_scale
,
v_scale
)
def
reshape_and_cache_flash
(
def
reshape_and_cache_flash
(
...
...
vllm/_ipex_ops.py
View file @
978aed53
...
@@ -59,7 +59,8 @@ class ipex_ops:
...
@@ -59,7 +59,8 @@ class ipex_ops:
max_context_len
:
int
,
max_context_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -99,7 +100,8 @@ class ipex_ops:
...
@@ -99,7 +100,8 @@ class ipex_ops:
max_context_len
:
int
,
max_context_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -227,7 +229,8 @@ class ipex_ops:
...
@@ -227,7 +229,8 @@ class ipex_ops:
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
)
->
None
:
assert
kv_cache_dtype
==
"auto"
assert
kv_cache_dtype
==
"auto"
ipex
.
llm
.
modules
.
PagedAttention
.
reshape_and_cache
(
ipex
.
llm
.
modules
.
PagedAttention
.
reshape_and_cache
(
...
...
vllm/attention/backends/abstract.py
View file @
978aed53
...
@@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
T
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
978aed53
...
@@ -327,7 +327,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -327,7 +327,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -368,7 +369,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -368,7 +369,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
@@ -405,7 +407,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -405,7 +407,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
blocksparse_local_blocks
=
self
.
local_blocks
,
blocksparse_local_blocks
=
self
.
local_blocks
,
blocksparse_vert_stride
=
self
.
vert_stride
,
blocksparse_vert_stride
=
self
.
vert_stride
,
...
...
vllm/attention/backends/flash_attn.py
View file @
978aed53
...
@@ -256,7 +256,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -256,7 +256,8 @@ class FlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
"""Forward pass with FlashAttention.
...
@@ -277,7 +278,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -277,7 +278,8 @@ class FlashAttentionImpl(AttentionImpl):
"FlashAttentionImpl"
)
"FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
...
...
vllm/attention/backends/flashinfer.py
View file @
978aed53
...
@@ -223,10 +223,12 @@ class FlashInferImpl(AttentionImpl):
...
@@ -223,10 +223,12 @@ class FlashInferImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
FlashInferMetadata
,
attn_metadata
:
FlashInferMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashInfer."
)
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
...
...
vllm/attention/backends/ipex_attn.py
View file @
978aed53
...
@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
kv_scale
==
1.0
assert
k
_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
...
@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
...
@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
...
@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment