change / sglang · Commits · 7d672d27

Commit 7d672d27 (unverified), authored Dec 22, 2024 by Xiaoyu Zhang, committed by GitHub on Dec 22, 2024.

[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel (#2509)

Parent: d4b17481
Changes: 1 changed file with 345 additions and 0 deletions.

benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py (new file, mode 100644, +345 / -0)
import itertools
import os
from typing import List

import numpy as np
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    # One program per request: axis 0 of the launch grid indexes the batch.
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    # Serial prefix sum over extend_lens locates this request's slice of
    # out_cache_loc; each program redoes O(pid) scalar loads.
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # Copy the extended tokens in BLOCK_SIZE chunks, serially within one program.
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
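
# Note: the baseline kernel above launches a 1D grid with one program per
# request, so all token blocks of a request are copied sequentially inside a
# single program. The optimized variant below instead launches a 2D grid of
# (request, token block) programs, parallelizing the per-request copy, and
# adds contiguity hints so memory accesses can be vectorized.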
@triton.jit
def write_req_to_token_pool_triton_optimize(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # 2D grid: axis 0 indexes the request, axis 1 indexes one BLOCK_SIZE-sized
    # chunk of that request's extended tokens.
    pid_batch = tl.program_id(0)
    pid_token = tl.program_id(1)

    req_pool_index = tl.load(req_pool_indices + pid_batch)
    pre_len = tl.load(pre_lens + pid_batch)
    seq_len = tl.load(seq_lens + pid_batch)
    extend_len = seq_len - pre_len

    # Still a serial prefix sum per program (see the TODO in the baseline kernel).
    cumsum_start = 0
    for i in range(pid_batch):
        cumsum_start += tl.load(extend_lens + i)

    token_start = pid_token * BLOCK_SIZE
    offset = tl.arange(0, BLOCK_SIZE)
    actual_offset = token_start + offset
    mask = actual_offset < extend_len

    src_ptr = out_cache_loc + cumsum_start + actual_offset
    src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
    value = tl.load(src_ptr, mask=mask)

    dst_ptr = (
        req_to_token_ptr
        + req_pool_index * req_to_token_ptr_stride
        + actual_offset
        + pre_len
    )
    dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)

    tl.store(dst_ptr, value, mask=mask)
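
# Note: tl.multiple_of and tl.max_contiguous are compiler hints, not data
# transformations. They promise the compiler that the pointer offsets are
# aligned to and contiguous over BLOCK_SIZE elements, which allows it to emit
# wider, coalesced loads and stores. Such hints are assertions by the caller:
# if cumsum_start or pre_len are not actually aligned as promised, the hint
# over-claims, so launch parameters should respect the assumed alignment.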
def write_req_to_token_pool_reference(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    pre_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    extend_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
) -> None:
    """Reference implementation using PyTorch"""
    for i in range(len(req_pool_indices)):
        req_pool_idx = req_pool_indices[i].item()
        pre_len = pre_lens[i].item()
        seq_len = seq_lens[i].item()
        extend_len = extend_lens[i].item()
        cumsum_start = sum(extend_lens[:i].tolist())

        # Copy values from out_cache_loc to req_to_token
        req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
            cumsum_start : cumsum_start + extend_len
        ]
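
# Worked example with the single-request test inputs below: for request 0 with
# req_pool_idx=42, pre_len=8, seq_len=22, extend_len=14, and cumsum_start=0,
# the copy is req_to_token[42, 8:22] = out_cache_loc[0:14], i.e. the 14 newly
# allocated cache slots are appended after the request's 8 prefix tokens.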
def test_write_req_to_token_pool():
    max_batch = 4097
    max_context_len = 6148
    batch_size = 1
    extend_len = 14

    # Initialize input tensors
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
    pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
    seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
    extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
    out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")

    # Create copies for reference implementation
    req_to_token_ref = req_to_token.clone()
    req_to_token_opt = req_to_token.clone()

    # Run original triton kernel
    write_req_to_token_pool_triton[(batch_size,)](
        req_to_token,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
    )

    # Run optimized triton kernel
    def grid(batch_size, extend_len):
        num_token_blocks = triton.cdiv(extend_len, 512)
        return (batch_size, num_token_blocks)
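
    # Grid axis 0 covers the batch, axis 1 covers ceil(extend_len / 512) token
    # blocks per request, matching the BLOCK_SIZE=512 passed to the kernel.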
    write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
        req_to_token_opt,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
        BLOCK_SIZE=512,
    )

    # Run reference implementation
    write_req_to_token_pool_reference(
        req_to_token_ref,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
    )

    # Compare results
    torch.testing.assert_close(req_to_token, req_to_token_ref)
    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)

    # Test case 2: batch size > 1
    batch_size = 3
    extend_lens_list = [14, 20, 30]
    total_extend_len = sum(extend_lens_list)

    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
    pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
    seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
    extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
    out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

    req_to_token_ref = req_to_token.clone()
    req_to_token_opt = req_to_token.clone()

    # Run original triton kernel
    write_req_to_token_pool_triton[(batch_size,)](
        req_to_token,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
    )

    # Run optimized triton kernel
    max_extend_len = max(extend_lens_list)
    write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
        req_to_token_opt,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
        BLOCK_SIZE=512,
    )
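
    # The grid is sized for the longest request (max_extend_len); programs that
    # fall beyond a shorter request's extend_len are masked off inside the
    # kernel by `mask = actual_offset < extend_len`, so over-provisioning the
    # grid is safe.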
    # Run reference implementation
    write_req_to_token_pool_reference(
        req_to_token_ref,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
    )

    # Compare results
    torch.testing.assert_close(req_to_token, req_to_token_ref)
    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)


def get_benchmark():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
    extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    configs = list(itertools.product(batch_sizes, extend_lens))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "extend_len"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["reference", "triton", "triton_optimize"],
            line_names=["PyTorch", "Triton", "Triton Optimized"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="write-req-to-token-pool-performance",
            args={},
        )
    )
    def benchmark(batch_size, extend_len, provider):
        max_batch = 256
        max_context_len = 16384
        extend_lens_list = [extend_len] * batch_size
        total_extend_len = sum(extend_lens_list)

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
        pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
        seq_lens = pre_lens + extend_len
        extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
        out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

        quantiles = [0.5, 0.2, 0.8]
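
        # do_bench returns the requested quantiles of the measured runtimes in
        # milliseconds; with [0.5, 0.2, 0.8] these are the median and the 20th
        # and 80th percentiles, converted to microseconds on return below.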
        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: write_req_to_token_pool_reference(
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: write_req_to_token_pool_triton[(batch_size,)](
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                    max_context_len,
                ),
                quantiles=quantiles,
            )
        else:

            def run_optimized():
                # Smaller blocks for short extends, larger blocks for long ones.
                block_size = 128 if extend_len <= 1024 else 512
                grid_config = (batch_size, triton.cdiv(extend_len, block_size))
                write_req_to_token_pool_triton_optimize[grid_config](
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                    max_context_len,
                    BLOCK_SIZE=block_size,
                )

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_optimized, quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
    """Run benchmark and save results"""
    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_write_req_to_token_pool()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/write_req_to_token_pool/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)
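
To reproduce, the script can be run directly on a CUDA-capable machine with PyTorch and Triton installed; the invocation below uses the script's own --save_path flag with its default output directory:

python benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py --save_path ./configs/benchmark_ops/write_req_to_token_pool/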