Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d9d35def
Unverified
Commit
d9d35def
authored
May 30, 2025
by
JieXin Liang
Committed by
GitHub
May 29, 2025
Browse files
[test] add ut and bm for get_last_loc (#6746)
parent
6df81e8a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
169 additions
and
0 deletions
+169
-0
benchmark/kernels/scheduler_batch/benchmark_get_last_loc_triton.py
.../kernels/scheduler_batch/benchmark_get_last_loc_triton.py
+169
-0
No files found.
benchmark/kernels/scheduler_batch/benchmark_get_last_loc_triton.py
0 → 100644
View file @
d9d35def
import
os
import
torch
import
triton
import
triton.language
as
tl
@
torch
.
compile
(
dynamic
=
True
)
def
get_last_loc_torch
(
req_to_token
:
torch
.
Tensor
,
req_pool_indices_tensor
:
torch
.
Tensor
,
prefix_lens_tensor
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
torch
.
where
(
prefix_lens_tensor
>
0
,
req_to_token
[
req_pool_indices_tensor
,
prefix_lens_tensor
-
1
],
torch
.
full_like
(
prefix_lens_tensor
,
-
1
),
)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-request last cached token location (Triton kernel).

    Each program handles BLOCK_SIZE consecutive requests: for every in-bounds
    request it writes req_to_token[pool_idx, prefix_len - 1] to `result`,
    or -1 when the request has an empty prefix.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < num_tokens

    lens = tl.load(prefix_lens_tensor + offs, mask=in_bounds, other=0)
    pool_idx = tl.load(req_pool_indices_tensor + offs, mask=in_bounds, other=0)

    # Only requests with a non-empty prefix have a valid last location;
    # the rest fall through to -1 via the `other` value.
    valid = lens > 0
    flat_idx = pool_idx * req_to_token_stride + (lens - 1)
    last_loc = tl.load(req_to_token + flat_idx, mask=valid, other=-1)

    tl.store(result + offs, last_loc, mask=in_bounds)
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Triton-backed equivalent of get_last_loc_torch.

    Launches get_last_loc_kernel over one block of 256 requests per program
    and returns a tensor shaped like `prefix_lens_tensor`.
    """
    block = 256
    n = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)

    launch_grid = (triton.cdiv(n, block),)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        n,
        req_to_token.stride(0),
        block,
    )
    return out
def test_get_last_loc():
    """Cross-check the Triton kernel against the torch.compile reference."""
    n_rows, n_cols, n_reqs = 4097, 6148, 20

    # Zero token table is enough: we only check that both implementations
    # gather the same entries (including the -1 empty-prefix fallback).
    token_table = torch.zeros((n_rows, n_cols), dtype=torch.int32, device="cuda")
    pool_indices = torch.arange(n_reqs, dtype=torch.int64, device="cuda")
    # Negative lengths exercise the prefix_len <= 0 branch.
    lens = torch.randint(
        -n_cols // 2, n_cols, (n_reqs,), dtype=torch.int64, device="cuda"
    )

    got = get_last_loc_triton(token_table, pool_indices, lens)
    want = get_last_loc_torch(token_table, pool_indices, lens)
    torch.testing.assert_close(got, want)
def get_benchmark():
    """Build the perf_report benchmark comparing the torch and Triton paths."""
    batch_sizes = [2**i for i in range(11)]  # 1 .. 1024

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=batch_sizes,
            line_arg="provider",
            line_vals=["reference", "triton"],
            line_names=["PyTorch", "Triton"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="get-last-loc-performance",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        max_batch = 2048
        max_context_len = 16384

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
        pre_lens = torch.randint(
            -max_context_len // 2,
            max_context_len,
            (batch_size,),
            dtype=torch.int64,
            device="cuda",
        )

        quantiles = [0.5, 0.2, 0.8]
        if provider == "reference":
            fn = lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)
        elif provider == "triton":
            fn = lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

        # ms -> us; return order follows the triton tutorial convention.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
    """Run benchmark and save results"""
    # Make sure the output directory exists before triton tries to write plots.
    os.makedirs(save_path, exist_ok=True)

    # Sanity-check correctness before timing anything.
    test_get_last_loc()
    print("Correctness test passed!")

    get_benchmark().run(print_data=True, save_path=save_path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/get_last_loc/",
        help="Path to save benchmark results",
    )
    run_benchmark(parser.parse_args().save_path)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment