Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2129 additions
and
144 deletions
+2129
-144
benchmarks/kernels/benchmark_mixtral_moe.py
benchmarks/kernels/benchmark_mixtral_moe.py
+71
-38
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+12
-13
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+38
-38
csrc/cache.h
csrc/cache.h
+8
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+80
-0
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+46
-46
csrc/ops.h
csrc/ops.h
+30
-5
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+78
-0
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_impl.cuh
csrc/punica/bgmv/bgmv_impl.cuh
+4
-1
csrc/punica/bgmv/generator.py
csrc/punica/bgmv/generator.py
+1
-0
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+1
-1
csrc/pybind.cpp
csrc/pybind.cpp
+8
-1
csrc/quantization/fp8/fp8_cuda_kernels.cu
csrc/quantization/fp8/fp8_cuda_kernels.cu
+24
-1
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+1722
-0
No files found.
benchmarks/kernels/benchmark_mixtral_moe.py
View file @
1591c68f
import
argparse
import
json
import
json
import
os
import
os
import
sys
import
sys
...
@@ -5,6 +6,7 @@ import sys
...
@@ -5,6 +6,7 @@ import sys
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
triton
import
triton
from
tqdm
import
tqdm
from
vllm.model_executor.layers.fused_moe
import
(
fused_moe
,
from
vllm.model_executor.layers.fused_moe
import
(
fused_moe
,
get_config_file_name
)
get_config_file_name
)
...
@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe,
...
@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe,
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
def
main
():
def
main
(
dtype
:
str
):
method
=
fused_moe
method
=
fused_moe
for
bs
in
[
for
bs
in
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
2048
,
3072
,
4096
]:
]:
run_grid
(
bs
,
method
=
method
)
run_grid
(
bs
,
method
=
method
,
dtype
=
dtype
)
def
run_grid
(
bs
,
method
):
def
run_grid
(
bs
,
method
,
dtype
:
str
):
d_model
=
4096
d_model
=
4096
num_total_experts
=
8
num_total_experts
=
8
top_k
=
2
top_k
=
2
...
@@ -34,39 +36,29 @@ def run_grid(bs, method):
...
@@ -34,39 +36,29 @@ def run_grid(bs, method):
num_trials
=
1
num_trials
=
1
configs
=
[]
configs
=
[]
if
bs
<=
16
:
BLOCK_SIZES_M
=
[
16
]
elif
bs
<=
32
:
BLOCK_SIZES_M
=
[
16
,
32
]
elif
bs
<=
64
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
]
elif
bs
<=
128
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
,
128
]
else
:
BLOCK_SIZES_M
=
[
16
,
32
,
64
,
128
,
256
]
for
block_size_n
in
[
32
,
64
,
128
,
256
]:
for
block_size_n
in
[
32
,
64
,
128
,
256
]:
for
block_size_m
in
BLOCK_SIZES_M
:
for
block_size_m
in
[
16
,
32
,
64
,
128
,
256
]
:
for
block_size_k
in
[
64
,
128
,
256
]:
for
block_size_k
in
[
64
,
128
,
256
]:
for
group_size_m
in
[
1
,
16
,
32
,
64
]:
for
group_size_m
in
[
1
,
16
,
32
,
64
]:
for
num_warps
in
[
4
,
8
]:
for
num_warps
in
[
4
,
8
]:
configs
.
append
({
for
num_stages
in
[
2
,
3
,
4
,
5
]:
"BLOCK_SIZE_M"
:
block_size_m
,
configs
.
append
({
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_K"
:
block_size_k
,
"BLOCK_SIZE_N"
:
block_size_n
,
"GROUP_SIZE_M"
:
group_size_m
,
"BLOCK_SIZE_K"
:
block_size_k
,
"num_warps"
:
num_warps
,
"GROUP_SIZE_M"
:
group_size_m
,
"num_stages"
:
4
,
"num_warps"
:
num_warps
,
})
"num_stages"
:
num_stages
,
})
best_config
=
None
best_config
=
None
best_time_us
=
1e20
best_time_us
=
1e20
for
config
in
configs
:
print
(
f
'
{
tp_size
=
}
{
bs
=
}
'
)
print
(
f
'
{
tp_size
=
}
{
bs
=
}
'
)
print
(
f
'
{
config
}
'
)
for
config
in
tqdm
(
config
s
):
# warmup
# warmup
print
(
'warming up'
)
try
:
try
:
for
_
in
range
(
num_warmup_trials
):
for
_
in
range
(
num_warmup_trials
):
run_timing
(
run_timing
(
...
@@ -79,12 +71,12 @@ def run_grid(bs, method):
...
@@ -79,12 +71,12 @@ def run_grid(bs, method):
model_intermediate_size
=
model_intermediate_size
,
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
method
=
method
,
config
=
config
,
config
=
config
,
dtype
=
dtype
,
)
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
continue
continue
# trial
# trial
print
(
'benchmarking'
)
for
_
in
range
(
num_trials
):
for
_
in
range
(
num_trials
):
kernel_dur_ms
=
run_timing
(
kernel_dur_ms
=
run_timing
(
num_calls
=
num_calls
,
num_calls
=
num_calls
,
...
@@ -96,6 +88,7 @@ def run_grid(bs, method):
...
@@ -96,6 +88,7 @@ def run_grid(bs, method):
model_intermediate_size
=
model_intermediate_size
,
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
method
=
method
,
config
=
config
,
config
=
config
,
dtype
=
dtype
,
)
)
kernel_dur_us
=
1000
*
kernel_dur_ms
kernel_dur_us
=
1000
*
kernel_dur_ms
...
@@ -105,16 +98,18 @@ def run_grid(bs, method):
...
@@ -105,16 +98,18 @@ def run_grid(bs, method):
best_config
=
config
best_config
=
config
best_time_us
=
kernel_dur_us
best_time_us
=
kernel_dur_us
print
(
f
'
{
kernel_dur_us
=
:.
1
f
}
{
model_dur_ms
=
:.
1
f
}
'
tqdm
.
write
(
f
'
{
bs
=
}
{
tp_size
=
}
{
top_k
=
}
{
num_total_experts
=
}
'
f
'
{
kernel_dur_us
=
:.
1
f
}
{
model_dur_ms
=
:.
1
f
}
'
f
'
{
d_model
=
}
{
model_intermediate_size
=
}
{
num_layers
=
}
'
)
f
'
{
bs
=
}
{
tp_size
=
}
{
top_k
=
}
{
num_total_experts
=
}
'
f
'
{
d_model
=
}
{
model_intermediate_size
=
}
{
num_layers
=
}
'
)
print
(
"best_time_us"
,
best_time_us
)
print
(
"best_time_us"
,
best_time_us
)
print
(
"best_config"
,
best_config
)
print
(
"best_config"
,
best_config
)
# holds Dict[str, Dict[str, int]]
# holds Dict[str, Dict[str, int]]
filename
=
get_config_file_name
(
num_total_experts
,
filename
=
get_config_file_name
(
num_total_experts
,
model_intermediate_size
//
tp_size
)
model_intermediate_size
//
tp_size
,
"float8"
if
dtype
==
"float8"
else
None
)
print
(
f
"writing config to file
{
filename
}
"
)
print
(
f
"writing config to file
{
filename
}
"
)
existing_content
=
{}
existing_content
=
{}
if
os
.
path
.
exists
(
filename
):
if
os
.
path
.
exists
(
filename
):
...
@@ -128,27 +123,48 @@ def run_grid(bs, method):
...
@@ -128,27 +123,48 @@ def run_grid(bs, method):
def
run_timing
(
num_calls
:
int
,
bs
:
int
,
d_model
:
int
,
num_total_experts
:
int
,
def
run_timing
(
num_calls
:
int
,
bs
:
int
,
d_model
:
int
,
num_total_experts
:
int
,
top_k
:
int
,
tp_size
:
int
,
model_intermediate_size
:
int
,
method
,
top_k
:
int
,
tp_size
:
int
,
model_intermediate_size
:
int
,
method
,
config
)
->
float
:
config
,
dtype
:
str
)
->
float
:
shard_intermediate_size
=
model_intermediate_size
//
tp_size
shard_intermediate_size
=
model_intermediate_size
//
tp_size
hidden_states
=
torch
.
rand
(
hidden_states
=
torch
.
rand
(
(
bs
,
d_model
),
(
bs
,
d_model
),
device
=
"cuda:0"
,
device
=
"cuda:0"
,
dtype
=
torch
.
b
float16
,
dtype
=
torch
.
float16
,
)
)
w
s
=
torch
.
rand
(
w
1
=
torch
.
rand
(
(
num_total_experts
,
2
*
shard_intermediate_size
,
d_model
),
(
num_total_experts
,
2
*
shard_intermediate_size
,
d_model
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
w2
s
=
torch
.
rand
(
w2
=
torch
.
rand
(
(
num_total_experts
,
d_model
,
shard_intermediate_size
),
(
num_total_experts
,
d_model
,
shard_intermediate_size
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
dtype
==
"float8"
:
w1
=
w1
.
to
(
torch
.
float8_e4m3fn
)
w2
=
w2
.
to
(
torch
.
float8_e4m3fn
)
w1_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
gating_output
=
F
.
softmax
(
torch
.
rand
(
gating_output
=
F
.
softmax
(
torch
.
rand
(
(
num_calls
,
bs
,
num_total_experts
),
(
num_calls
,
bs
,
num_total_experts
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
...
@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
...
@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for
i
in
range
(
num_calls
):
for
i
in
range
(
num_calls
):
hidden_states
=
method
(
hidden_states
=
method
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
w1
=
ws
,
w1
=
w1
,
w2
=
w2s
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
gating_output
=
gating_output
[
i
],
gating_output
=
gating_output
[
i
],
topk
=
2
,
topk
=
2
,
renormalize
=
True
,
renormalize
=
True
,
inplace
=
True
,
inplace
=
True
,
override_config
=
config
,
override_config
=
config
,
use_fp8
=
dtype
==
"float8"
,
)
)
end_event
.
record
()
end_event
.
record
()
end_event
.
synchronize
()
end_event
.
synchronize
()
...
@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
...
@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
sys
.
exit
(
main
())
parser
=
argparse
.
ArgumentParser
(
prog
=
'benchmark_mixtral_moe'
,
description
=
'Benchmark and tune the fused_moe kernel'
,
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'float8'
,
'float16'
],
help
=
'Data type used for fused_moe kernel computations'
,
)
args
=
parser
.
parse_args
()
sys
.
exit
(
main
(
args
.
dtype
))
benchmarks/kernels/benchmark_paged_attention.py
View file @
1591c68f
...
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
...
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def
main
(
def
main
(
version
:
str
,
version
:
str
,
num_seqs
:
int
,
num_seqs
:
int
,
context
_len
:
int
,
seq
_len
:
int
,
num_query_heads
:
int
,
num_query_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
...
@@ -48,12 +48,12 @@ def main(
...
@@ -48,12 +48,12 @@ def main(
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
)
device
=
device
)
context
_lens
=
[
context
_len
for
_
in
range
(
num_seqs
)]
seq
_lens
=
[
seq
_len
for
_
in
range
(
num_seqs
)]
max_
context
_len
=
max
(
context
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
context
_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq
_lens
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
device
)
# Create the block tables.
# Create the block tables.
max_num_blocks_per_seq
=
(
max_
context
_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_
seq
_len
+
block_size
-
1
)
//
block_size
block_tables
=
[]
block_tables
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
block_table
=
[
block_table
=
[
...
@@ -77,8 +77,7 @@ def main(
...
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
# Prepare for the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
if
version
==
"v2"
:
if
version
==
"v2"
:
num_partitions
=
((
max_context_len
+
PARTITION_SIZE
-
1
)
//
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
PARTITION_SIZE
)
tmp_output
=
torch
.
empty
(
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
dtype
=
output
.
dtype
,
...
@@ -110,9 +109,9 @@ def main(
...
@@ -110,9 +109,9 @@ def main(
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
...
@@ -129,9 +128,9 @@ def main(
...
@@ -129,9 +128,9 @@ def main(
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
...
@@ -166,7 +165,7 @@ if __name__ == '__main__':
...
@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices
=
[
"v1"
,
"v2"
],
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
default
=
"v2"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--
context-
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--
seq_
len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
parser
.
add_argument
(
"--head-size"
,
...
@@ -199,7 +198,7 @@ if __name__ == '__main__':
...
@@ -199,7 +198,7 @@ if __name__ == '__main__':
main
(
main
(
version
=
args
.
version
,
version
=
args
.
version
,
num_seqs
=
args
.
batch_size
,
num_seqs
=
args
.
batch_size
,
context
_len
=
args
.
context
_len
,
seq
_len
=
args
.
seq
_len
,
num_query_heads
=
args
.
num_query_heads
,
num_query_heads
=
args
.
num_query_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
num_kv_heads
=
args
.
num_kv_heads
,
head_size
=
args
.
head_size
,
head_size
=
args
.
head_size
,
...
...
csrc/attention/attention_kernels.cu
View file @
1591c68f
...
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
...
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
q_stride
,
...
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
...
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const
int
partition_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
context
_len
)
{
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq
_len
)
{
// No work to do. Terminate the thread block.
// No work to do. Terminate the thread block.
return
;
return
;
}
}
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
context
_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_
seq
_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
context
_blocks
);
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_
seq
_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
context
_len
);
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq
_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
...
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
...
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
context
_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq
_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
context
_len
;
const
bool
mask
=
token_idx
>=
seq
_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
...
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
...
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
}
else
{
}
else
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
}
if
(
block_idx
==
num_
context
_blocks
-
1
)
{
if
(
block_idx
==
num_
seq
_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
context
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq
_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
...
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
...
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
q_stride
,
...
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
...
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const
float
kv_scale
)
{
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
>
(
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
}
...
@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
...
@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
q_stride
,
...
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
...
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const
float
kv_scale
)
{
const
float
kv_scale
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
,
PARTITION_SIZE
>
(
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
IS_FP8_KV_CACHE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
block_tables
,
seq
_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
);
}
}
...
@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
if
(
num_partitions
==
1
)
{
// No need to reduce. Only copy tmp_out to out.
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
...
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
num_kv_heads, \
scale, \
scale, \
block_tables_ptr, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr,
\
max_num_blocks_per_seq, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
alibi_slopes_ptr, \
q_stride, \
q_stride, \
...
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
...
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int
num_kv_heads
,
int
num_kv_heads
,
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
...
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_
context
_len
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
padded_max_
seq
_len
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_
context
_len
*
sizeof
(
float
);
int
logits_size
=
padded_max_
seq
_len
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
// Keep that in sync with the logic here!
...
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
...
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
num_kv_heads, \
scale, \
scale, \
block_tables, \
block_tables, \
context
_lens, \
seq
_lens, \
max_
context
_len, \
max_
seq
_len, \
alibi_slopes, \
alibi_slopes, \
kv_scale);
kv_scale);
...
@@ -746,9 +746,9 @@ void paged_attention_v1(
...
@@ -746,9 +746,9 @@ void paged_attention_v1(
int
num_kv_heads
,
// [num_heads]
int
num_kv_heads
,
// [num_heads]
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
float
kv_scale
)
{
...
@@ -790,7 +790,7 @@ void paged_attention_v1(
...
@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
num_kv_heads, \
scale, \
scale, \
block_tables_ptr, \
block_tables_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_blocks_per_seq, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
alibi_slopes_ptr, \
q_stride, \
q_stride, \
...
@@ -803,7 +803,7 @@ void paged_attention_v1(
...
@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
tmp_out_ptr, \
context
_lens_ptr, \
seq
_lens_ptr, \
max_num_partitions);
max_num_partitions);
template
<
template
<
...
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
...
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int
num_kv_heads
,
int
num_kv_heads
,
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
float
kv_scale
)
{
float
kv_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
...
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
...
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
...
@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
...
@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
num_kv_heads, \
scale, \
scale, \
block_tables, \
block_tables, \
context
_lens, \
seq
_lens, \
max_
context
_len, \
max_
seq
_len, \
alibi_slopes, \
alibi_slopes, \
kv_scale);
kv_scale);
...
@@ -943,9 +943,9 @@ void paged_attention_v2(
...
@@ -943,9 +943,9 @@ void paged_attention_v2(
int
num_kv_heads
,
// [num_heads]
int
num_kv_heads
,
// [num_heads]
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
int
block_size
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
float
kv_scale
)
{
...
...
csrc/cache.h
View file @
1591c68f
...
@@ -24,6 +24,14 @@ void reshape_and_cache(
...
@@ -24,6 +24,14 @@ void reshape_and_cache(
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
const
float
kv_scale
);
const
float
kv_scale
);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
// Just for unittest
// Just for unittest
void
convert_fp8
(
void
convert_fp8
(
torch
::
Tensor
&
src_cache
,
torch
::
Tensor
&
src_cache
,
...
...
csrc/cache_kernels.cu
View file @
1591c68f
...
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
...
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
reshape_and_cache_flash_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
k_cache
,
// [num_blocks, block_size, num_heads, head_size]
scalar_t
*
__restrict__
v_cache
,
// [num_blocks, block_size, num_heads, head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
if
(
slot_idx
<
0
)
{
return
;
}
const
int64_t
block_idx
=
slot_idx
/
block_size
;
const
int64_t
block_offset
=
slot_idx
%
block_size
;
const
int
n
=
num_heads
*
head_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
)
{
const
int64_t
src_key_idx
=
token_idx
*
key_stride
+
i
;
const
int64_t
src_value_idx
=
token_idx
*
value_stride
+
i
;
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_size
+
head_idx
*
head_size
+
head_offset
;
k_cache
[
tgt_value_idx
]
=
key
[
src_key_idx
];
v_cache
[
tgt_value_idx
]
=
value
[
src_value_idx
];
}
}
}
// namespace vllm
}
// namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
...
@@ -275,6 +310,51 @@ void reshape_and_cache(
...
@@ -275,6 +310,51 @@ void reshape_and_cache(
}
}
}
}
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
value
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
k_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
v_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
kv_cache_dtype
)
{
// FIXME: only support auto datatype, does not support fp8
if
(
kv_cache_dtype
!=
"auto"
)
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
int
block_size
=
k_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
k_cache
.
stride
(
0
);
TORCH_CHECK
(
k_cache
.
stride
(
0
)
==
v_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_flash"
,
[
&
]
{
vllm
::
reshape_and_cache_flash_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
k_cache
.
data_ptr
<
scalar_t
>
(),
v_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
block_stride
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
);
});
}
namespace
vllm
{
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
>
...
...
csrc/cpu/attention.cpp
View file @
1591c68f
...
@@ -70,11 +70,11 @@ template <typename T>
...
@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE
std
::
pair
<
T
,
T
>
FORCE_INLINE
std
::
pair
<
T
,
T
>
reduceSoftmaxAlibi
(
T
*
data
,
const
int
size
,
const
int
capacity
,
reduceSoftmaxAlibi
(
T
*
data
,
const
int
size
,
const
int
capacity
,
const
float
alibi_slope
,
const
int
start_index
,
const
float
alibi_slope
,
const
int
start_index
,
const
int
context
_len
)
{
const
int
seq
_len
)
{
data
[
0
]
+=
alibi_slope
*
(
start_index
-
context
_len
+
1
);
data
[
0
]
+=
alibi_slope
*
(
start_index
-
seq
_len
+
1
);
T
max
=
data
[
0
];
T
max
=
data
[
0
];
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
context
_len
+
1
);
T
qk
=
data
[
i
]
+
alibi_slope
*
(
start_index
+
i
-
seq
_len
+
1
);
data
[
i
]
=
qk
;
data
[
i
]
=
qk
;
max
=
max
>=
qk
?
max
:
qk
;
max
=
max
>=
qk
?
max
:
qk
;
}
}
...
@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
...
@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
num_kv_heads
,
const
float
scale
,
const
int
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
...
@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert
(
BLOCK_SIZE
==
16
);
static_assert
(
BLOCK_SIZE
==
16
);
int
max_
context
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
seq
_len
=
max_num_blocks_per_seq
*
BLOCK_SIZE
;
int
max_
context
_len_padded
=
(
max_
context
_len
+
15
)
&
0xFFFFFFF0
;
int
max_
seq
_len_padded
=
(
max_
seq
_len
+
15
)
&
0xFFFFFFF0
;
TORCH_CHECK
((
max_
context
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
TORCH_CHECK
((
max_
seq
_len_padded
*
sizeof
(
float
))
%
64
==
0
);
const
int
parallel_work_item_num
=
omp_get_max_threads
();
const
int
parallel_work_item_num
=
omp_get_max_threads
();
size_t
logits_bytes
=
size_t
logits_bytes
=
parallel_work_item_num
*
max_
context
_len_padded
*
sizeof
(
float
);
parallel_work_item_num
*
max_
seq
_len_padded
*
sizeof
(
float
);
float
*
logits
=
(
float
*
)
std
::
aligned_alloc
(
float
*
logits
=
(
float
*
)
std
::
aligned_alloc
(
64
,
logits_bytes
);
// Cacheline alignment for each context token.
64
,
logits_bytes
);
// Cacheline alignment for each context token.
// [parallel_work_item_num, max_
context
_len_padded]
// [parallel_work_item_num, max_
seq
_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
int
context_len
=
context
_lens
[
seq_idx
];
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
*
seq_block_table
=
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
;
block_tables
+
max_num_blocks_per_seq
*
seq_idx
;
const
int
block_num
=
(
context
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
block_num
=
(
seq
_len
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int64_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
int64_t
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
scalar_t
*
__restrict__
q_vec_ptr
=
const
scalar_t
*
__restrict__
q_vec_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
const
int
last_block_token_num
=
const
int
last_block_token_num
=
context
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
seq
_len
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
float
*
__restrict__
thread_block_logits
=
float
*
__restrict__
thread_block_logits
=
logits
+
omp_get_thread_num
()
*
max_
context
_len_padded
;
logits
+
omp_get_thread_num
()
*
max_
seq
_len_padded
;
// Compute logits
// Compute logits
for
(
int
block_idx
=
0
;
block_idx
<
block_num
;
++
block_idx
)
{
for
(
int
block_idx
=
0
;
block_idx
<
block_num
;
++
block_idx
)
{
...
@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
...
@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax
// Compute softmax
if
(
alibi_slopes
)
{
if
(
alibi_slopes
)
{
reduceSoftmaxAlibi
(
thread_block_logits
,
context
_len
,
reduceSoftmaxAlibi
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
0
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
0
,
context
_len
);
seq
_len
);
}
else
{
}
else
{
reduceSoftmax
(
thread_block_logits
,
context
_len
,
reduceSoftmax
(
thread_block_logits
,
seq
_len
,
block_num
*
BLOCK_SIZE
);
block_num
*
BLOCK_SIZE
);
}
}
...
@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
...
@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr,
context
_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr,
seq
_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
num_heads);
...
@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
...
@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void
paged_attention_v1_impl_launcher
(
void
paged_attention_v1_impl_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
...
@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
switch
(
head_size
)
{
case
64
:
case
64
:
...
@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
...
@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context
_lens, max_
context
_len, alibi_slopes);
seq
_lens, max_
seq
_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
switch (block_size) { \
...
@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
...
@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
...
@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const
int
num_kv_heads
,
const
float
scale
,
const
int
num_kv_heads
,
const
float
scale
,
const
int
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
...
@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
...
@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for
(
int
partition_idx
=
0
;
partition_idx
<
max_num_partitions
;
for
(
int
partition_idx
=
0
;
partition_idx
<
max_num_partitions
;
++
partition_idx
)
{
++
partition_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
start_token_idx
=
partition_idx
*
PARTITION_SIZE
;
const
int
start_token_idx
=
partition_idx
*
PARTITION_SIZE
;
if
(
start_token_idx
>=
context
_len
)
if
(
start_token_idx
>=
seq
_len
)
continue
;
continue
;
const
int
partition_num
=
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
const
bool
no_reduce
=
(
partition_num
==
1
);
const
bool
no_reduce
=
(
partition_num
==
1
);
const
int
context_
token_num
=
const
int
token_num
=
(
std
::
min
(
context
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
(
std
::
min
(
seq
_len
,
start_token_idx
+
PARTITION_SIZE
)
-
start_token_idx
);
start_token_idx
);
const
int
block_num
=
const
int
block_num
=
(
context_
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
(
token_num
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
const
int
last_block_token_num
=
const
int
last_block_token_num
=
context_
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
token_num
-
(
block_num
-
1
)
*
BLOCK_SIZE
;
const
int
*
seq_block_table
=
block_tables
+
const
int
*
seq_block_table
=
block_tables
+
max_num_blocks_per_seq
*
seq_idx
+
max_num_blocks_per_seq
*
seq_idx
+
start_token_idx
/
BLOCK_SIZE
;
start_token_idx
/
BLOCK_SIZE
;
...
@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
...
@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std
::
pair
<
float
,
float
>
max_and_sum
;
std
::
pair
<
float
,
float
>
max_and_sum
;
if
(
alibi_slopes
)
{
if
(
alibi_slopes
)
{
max_and_sum
=
reduceSoftmaxAlibi
(
max_and_sum
=
reduceSoftmaxAlibi
(
logits
,
context_
token_num
,
block_num
*
BLOCK_SIZE
,
logits
,
token_num
,
block_num
*
BLOCK_SIZE
,
alibi_slopes
[
head_idx
],
start_token_idx
,
context
_len
);
alibi_slopes
[
head_idx
],
start_token_idx
,
seq
_len
);
}
else
{
}
else
{
max_and_sum
=
reduceSoftmax
(
logits
,
context_
token_num
,
max_and_sum
=
reduceSoftmax
(
logits
,
token_num
,
block_num
*
BLOCK_SIZE
);
block_num
*
BLOCK_SIZE
);
}
}
...
@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
...
@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
#pragma omp parallel for collapse(2) schedule(static, 1)
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
if
(
partition_num
==
1
)
continue
;
continue
;
...
@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
...
@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
num_seqs
;
++
seq_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
head_idx
=
0
;
head_idx
<
num_heads
;
++
head_idx
)
{
for
(
int
group_idx
=
0
;
group_idx
<
head_group_num
;
++
group_idx
)
{
for
(
int
group_idx
=
0
;
group_idx
<
head_group_num
;
++
group_idx
)
{
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_num
=
const
int
partition_num
=
(
context
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
(
seq
_len
+
PARTITION_SIZE
-
1
)
/
PARTITION_SIZE
;
if
(
partition_num
==
1
)
if
(
partition_num
==
1
)
continue
;
continue
;
...
@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
...
@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq
_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
max_num_partitions);
...
@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
...
@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
)
{
int
num_seqs
=
query
.
size
(
0
);
int
num_seqs
=
query
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
head_size
=
query
.
size
(
2
);
...
@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
...
@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
key_cache_ptr
=
reinterpret_cast
<
T
*>
(
key_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
T
*
value_cache_ptr
=
reinterpret_cast
<
T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
switch
(
head_size
)
{
switch
(
head_size
)
{
case
64
:
case
64
:
...
@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
...
@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables,
context
_lens, block_size, \
num_kv_heads, scale, block_tables,
seq
_lens, block_size, \
max_
context
_len, alibi_slopes);
max_
seq
_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
switch (block_size) { \
...
@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
...
@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
int
block_size
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
TORCH_CHECK
(
kv_scale
==
1.0
f
);
...
...
csrc/ops.h
View file @
1591c68f
...
@@ -10,9 +10,9 @@ void paged_attention_v1(
...
@@ -10,9 +10,9 @@ void paged_attention_v1(
int
num_kv_heads
,
int
num_kv_heads
,
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
float
kv_scale
);
...
@@ -28,9 +28,9 @@ void paged_attention_v2(
...
@@ -28,9 +28,9 @@ void paged_attention_v2(
int
num_kv_heads
,
int
num_kv_heads
,
float
scale
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
seq
_lens
,
int
block_size
,
int
block_size
,
int
max_
context
_len
,
int
max_
seq
_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
float
kv_scale
);
...
@@ -124,6 +124,26 @@ torch::Tensor marlin_gemm(
...
@@ -124,6 +124,26 @@ torch::Tensor marlin_gemm(
int64_t
size_m
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_n
,
int64_t
size_k
);
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
#endif
#endif
void
squeezellm_gemm
(
void
squeezellm_gemm
(
...
@@ -146,7 +166,12 @@ void gptq_shuffle(
...
@@ -146,7 +166,12 @@ void gptq_shuffle(
torch
::
Tensor
q_perm
,
torch
::
Tensor
q_perm
,
int
bit
);
int
bit
);
void
scaled_fp8_quant
(
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
torch
::
Tensor
&
scale
);
...
...
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_config.h
View file @
1591c68f
...
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
...
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on
// clang-format on
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_half
,
nv_half
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
float
,
nv_half
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
float
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
float
,
nv_half
)
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
View file @
1591c68f
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_half
,
nv_half
)
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_impl.cuh
View file @
1591c68f
...
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr
int
tz
=
4
;
constexpr
int
tz
=
4
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
feat_in
<
feat_out
)
{
if
constexpr
(
feat_in
<
=
feat_out
)
{
static_assert
(
feat_in
%
vec_size
==
0
);
static_assert
(
feat_in
%
vec_size
==
0
);
constexpr
int
tx
=
feat_in
/
vec_size
;
constexpr
int
tx
=
feat_in
/
vec_size
;
...
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
INST_BGMV(wide, narrow, in_T, out_T, W_T)
csrc/punica/bgmv/generator.py
View file @
1591c68f
...
@@ -10,6 +10,7 @@ TEMPLATE = """
...
@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
"""
.
lstrip
()
# noqa: E501
"""
.
lstrip
()
# noqa: E501
for
input_dtype
in
DTYPES
:
for
input_dtype
in
DTYPES
:
...
...
csrc/punica/punica_ops.cc
View file @
1591c68f
...
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
...
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW
(
CASE
,
_
,
_
,
_
)
FOR_BGMV_WIDE_NARROW
(
CASE
,
_
,
_
,
_
)
FOR_INST_BGMV_WIDE_NARROW
(
CASE_ONESIDE
,
_
,
_
,
_
)
#undef CASE
#undef CASE
#undef CASE_ONESIDE
#undef CASE_ONESIDE
default:
default:
return
false
;
return
false
;
}
}
return
true
;
return
true
;
}
}
...
...
csrc/pybind.cpp
View file @
1591c68f
...
@@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
#endif
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"scaled_fp8_quant"
,
&
scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
ops
.
def
(
"moe_align_block_size"
,
"moe_align_block_size"
,
&
moe_align_block_size
,
&
moe_align_block_size
,
...
@@ -93,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -93,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache"
,
"reshape_and_cache"
,
&
reshape_and_cache
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
cache_ops
.
def
(
"convert_fp8"
,
"convert_fp8"
,
&
convert_fp8
,
&
convert_fp8
,
...
...
csrc/quantization/fp8/fp8_cuda_kernels.cu
View file @
1591c68f
...
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
...
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
}
// namespace vllm
}
// namespace vllm
void
scaled_fp8_quant
(
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
dim3
grid
(
num_tokens
);
dim3
block
(
1024
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
torch
::
Tensor
&
scale
)
// [1]
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
0 → 100644
View file @
1591c68f
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment