Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b6dde330
Unverified
Commit
b6dde330
authored
Nov 13, 2024
by
Pavani Majety
Committed by
GitHub
Nov 13, 2024
Browse files
[Core] Flashinfer - Remove advance step size restriction (#10282)
parent
1b886aa1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
28 deletions
+38
-28
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+38
-28
No files found.
csrc/prepare_inputs/advance_step.cu
View file @
b6dde330
...
...
@@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t,
}
}
/// each thread processes a block per query
__global__
void
advance_step_flashinfer_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
...
...
@@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
// Update paged_kv_indptr
if
(
idx
==
0
)
{
paged_kv_indptr_ptr
[
idx
]
=
0
;
}
if
(
idx
<
num_queries
)
{
int
sum
=
0
;
for
(
int
i
=
0
;
i
<=
idx
;
++
i
)
{
...
...
@@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}
__global__
void
advance_step_flashinfer_indices_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_indices_ptr
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
max_num_blocks_per_seq
,
int
*
paged_kv_indices_ptr
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
int
row
=
idx
/
block_tables_stride
;
int
col
=
idx
%
block_tables_stride
;
if
(
row
<
num_queries
&&
col
<
block_table_bound_ptr
[
row
])
{
paged_kv_indices_ptr
[
paged_kv_indptr_ptr
[
row
]
+
col
]
=
block_tables_ptr
[
row
*
block_tables_stride
+
col
];
// note: max_num_blocks_per_seq = block_tables.stride(0)
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if
(
num_queries
<
tid
&&
tid
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
tid
]
=
paged_kv_indptr_ptr
[
num_queries
];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if
(
num_queries
<
row
&&
row
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
row
]
=
paged_kv_indptr_ptr
[
num_queries
];
// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for
(
int
idx
=
tid
;
idx
<
(
num_seqs
*
max_num_blocks_per_seq
);
idx
+=
(
gridDim
.
x
*
blockDim
.
x
))
{
// block_tables-row = paged_kv_indptr[queryNum]
int
queryNum
=
idx
/
max_num_blocks_per_seq
;
int
col
=
idx
%
max_num_blocks_per_seq
;
if
(
queryNum
<
num_queries
&&
col
<
block_table_bound_ptr
[
queryNum
])
{
int
indices_arr_idx
=
paged_kv_indptr_ptr
[
queryNum
]
+
col
;
int
block_tables_idx
=
queryNum
*
max_num_blocks_per_seq
+
col
;
paged_kv_indices_ptr
[
indices_arr_idx
]
=
block_tables_ptr
[
block_tables_idx
];
}
}
}
...
...
@@ -247,22 +263,16 @@ void advance_step_flashinfer(
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
if
(
logging
)
{
printf
(
"launching kernel with %d blocks
\n
"
,
blocks
);
}
// TODO(will): support arbitrary block_tables stride
if
((
blocks
*
threads
)
/
block_tables
.
stride
(
0
)
<
num_queries
)
{
TORCH_CHECK
(
false
,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,"
,
" increasing the block size or take smaller steps."
,
" num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
int
block_tables_stride
=
block_tables
.
stride
(
0
);
TORCH_CHECK
((
blocks
*
threads
>
num_queries
),
"multi-step: not enough threads to map to num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
if
(
logging
)
{
printf
(
"launching kernels with %d blocks and %d threads
\n
"
,
blocks
,
threads
);
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
...
...
@@ -281,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment