Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
445a2a4d
Unverified
Commit
445a2a4d
authored
Apr 10, 2026
by
Ganesh R
Committed by
GitHub
Apr 10, 2026
Browse files
feat(cpu): add CPU support for draft model speculative decoding (#32662)
Signed-off-by:
R
<
Ganesh.R@amd.com
>
parent
55d037e2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
945 additions
and
7 deletions
+945
-7
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+3
-0
csrc/cpu/spec_decode_utils.cpp
csrc/cpu/spec_decode_utils.cpp
+409
-0
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+119
-0
vllm/utils/cpu_triton_utils.py
vllm/utils/cpu_triton_utils.py
+274
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+4
-5
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+15
-1
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+121
-1
No files found.
cmake/cpu_extension.cmake
View file @
445a2a4d
...
@@ -349,6 +349,7 @@ endif()
...
@@ -349,6 +349,7 @@ endif()
set
(
VLLM_EXT_SRC
set
(
VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/activation.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pos_encoding.cpp"
...
@@ -383,6 +384,7 @@ if (ENABLE_X86_ISA)
...
@@ -383,6 +384,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/dnnl_kernels.cpp"
"csrc/cpu/dnnl_kernels.cpp"
"csrc/cpu/torch_bindings.cpp"
"csrc/cpu/torch_bindings.cpp"
...
@@ -395,6 +397,7 @@ if (ENABLE_X86_ISA)
...
@@ -395,6 +397,7 @@ if (ENABLE_X86_ISA)
set
(
VLLM_EXT_SRC_AVX2
set
(
VLLM_EXT_SRC_AVX2
"csrc/cpu/utils.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
# TODO: Remove these files
...
...
csrc/cpu/spec_decode_utils.cpp
0 → 100644
View file @
445a2a4d
#include "cpu_types.hpp"
#include <algorithm>
namespace
cpu_utils
{
// Computes, for each request, the index of the token to sample from and the
// number of rejected draft tokens.
//
// The draft-token count of request i is the difference between consecutive
// entries of cu_num_draft_tokens (request 0 starts at offset 0). If a request
// has draft tokens, rejected = drafts + 1 - valid_sampled_tokens_count[i];
// otherwise rejected = 0. The sample index is the last token of the request's
// query span (query_start_loc_gpu[i + 1] - 1) moved back by `rejected`.
//
// Raw pointer access fixes the element types: cu_num_draft_tokens,
// valid_sampled_tokens_count and num_rejected_tokens_gpu are read/written as
// int64; query_start_loc_gpu and token_indices_to_sample as int32.
// Each request is handled independently inside an OpenMP parallel loop.
void eagle_prepare_inputs_padded_kernel_impl(
    const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& valid_sampled_tokens_count,
    const torch::Tensor& query_start_loc_gpu,
    torch::Tensor& token_indices_to_sample,
    torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs) {
  const int64_t* cumulative_drafts = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* valid_counts = valid_sampled_tokens_count.data_ptr<int64_t>();
  const int32_t* query_starts = query_start_loc_gpu.data_ptr<int32_t>();
  int32_t* sample_indices = token_indices_to_sample.data_ptr<int32_t>();
  int64_t* rejected_counts = num_rejected_tokens_gpu.data_ptr<int64_t>();

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    // Per-request draft count from the cumulative sums.
    const int64_t prev_total = (req == 0) ? 0 : cumulative_drafts[req - 1];
    const int64_t drafts = cumulative_drafts[req] - prev_total;
    const int64_t valid = valid_counts[req];

    // Requests without draft tokens never report rejections.
    const int64_t rejected = (drafts > 0) ? (drafts + 1 - valid) : 0;

    // Last token of this request's query span, stepped back past rejections.
    const int32_t last_query_token = query_starts[req + 1] - 1;
    sample_indices[req] = static_cast<int32_t>(last_query_token - rejected);
    rejected_counts[req] = rejected;
  }
}
// Selects the next token id for every request and counts its valid sampled
// tokens.
//
// A sampled token is valid when it is neither -1 nor >= vocab_size. For each
// request the first num_sampled_tokens_per_req entries of its row in
// sampled_token_ids are scanned (rows are addressed via stride(0)).
//   - discard_request_mask[i] set: next token = backup_next_token_ids[i],
//     valid count = 0.
//   - otherwise: next token = last valid token in the row, or the backup id
//     when the row holds no valid token; valid count = number of valid tokens.
//
// All integer tensors are accessed as int64 and the mask as bool. Requests are
// processed independently inside an OpenMP parallel loop.
void eagle_prepare_next_token_padded_kernel_impl(
    const torch::Tensor& sampled_token_ids,
    const torch::Tensor& discard_request_mask,
    const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
    torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
    const int64_t num_sampled_tokens_per_req, const int64_t num_reqs) {
  const int64_t* sampled = sampled_token_ids.data_ptr<int64_t>();
  const bool* discard = discard_request_mask.data_ptr<bool>();
  const int64_t* backup = backup_next_token_ids.data_ptr<int64_t>();
  int64_t* next_out = next_token_ids.data_ptr<int64_t>();
  int64_t* count_out = valid_sampled_tokens_count.data_ptr<int64_t>();
  const int64_t row_stride = sampled_token_ids.stride(0);

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    const int64_t* row = sampled + req * row_stride;

    // Scan the row, remembering the last valid token seen.
    int64_t num_valid = 0;
    int64_t last_valid = -1;
    for (int64_t col = 0; col < num_sampled_tokens_per_req; ++col) {
      const int64_t tok = row[col];
      if (tok == -1 || tok >= vocab_size) {
        continue;
      }
      ++num_valid;
      last_valid = tok;
    }

    // Discarded requests always fall back to the backup token.
    if (discard[req]) {
      next_out[req] = backup[req];
      count_out[req] = 0;
      continue;
    }

    if (num_valid > 0) {
      next_out[req] = last_valid;
    } else {
      next_out[req] = backup[req];
    }
    count_out[req] = num_valid;
  }
}
// Advances every active request by one position and derives the matching
// slot id and sequence length.
//
// Rows [0, positions.size(0)) are active; rows from there up to
// out_slot_mapping.size(0) only receive PAD_ID in out_slot_mapping.
// For an active row:
//   - the new position is positions[row] + 1; if it reaches max_model_len the
//     clamped position becomes 0 and the slot is PAD_ID;
//   - otherwise the slot is block_table[row][block] * block_size + offset,
//     where block = position / block_size (capped at the last table column)
//     and offset = position % block_size;
//   - seq_lens[row] is updated in place to 1 on overflow, else seq_len + 1,
//     and never exceeds max_model_len.
//
// positions, out_clamped_positions and out_slot_mapping are accessed as int64;
// block_table and seq_lens as int32. Rows are processed independently inside
// an OpenMP parallel loop.
void eagle_step_slot_mapping_metadata_kernel_impl(
    const torch::Tensor& positions, const torch::Tensor& block_table,
    torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
    torch::Tensor& out_slot_mapping, const int64_t block_size,
    const int64_t max_model_len, const int64_t PAD_ID) {
  const int64_t num_active = positions.size(0);
  const int64_t num_rows = out_slot_mapping.size(0);

  const int64_t* positions_in = positions.data_ptr<int64_t>();
  const int32_t* table = block_table.data_ptr<int32_t>();
  int32_t* lens = seq_lens.data_ptr<int32_t>();
  int64_t* clamped_out = out_clamped_positions.data_ptr<int64_t>();
  int64_t* slots_out = out_slot_mapping.data_ptr<int64_t>();

  const int64_t table_row_stride = block_table.stride(0);
  const int64_t last_block_col = block_table.size(1) - 1;

#pragma omp parallel for
  for (int64_t row = 0; row < num_rows; ++row) {
    // Padding rows beyond the active batch only get a PAD slot.
    if (row >= num_active) {
      slots_out[row] = PAD_ID;
      continue;
    }

    const int64_t advanced = positions_in[row] + 1;
    const bool overflow = advanced >= max_model_len;
    const int64_t pos = overflow ? 0 : advanced;
    clamped_out[row] = pos;

    // Cap the column so the block table lookup stays inside the row.
    const int64_t block_col = std::min(pos / block_size, last_block_col);
    const int32_t block_id = table[row * table_row_stride + block_col];
    const int64_t slot = block_id * block_size + (pos % block_size);
    if (overflow) {
      slots_out[row] = PAD_ID;
    } else {
      slots_out[row] = slot;
    }

    const int32_t old_len = lens[row];
    int32_t updated_len = 1;
    if (!overflow) {
      updated_len = old_len + 1;
    }
    lens[row] = std::min(updated_len, static_cast<int32_t>(max_model_len));
  }
}
// Copies target tokens into an expanded per-request output layout and fills
// in the bonus, parallel-drafting and rejected slots.
//
// For request i the output region starts at
//   query_start_loc[i] + i * (num_padding_slots_per_request - shift)
// where shift is 1 when shift_input_ids is set and 0 otherwise. Slots are laid
// out as: [valid tokens | bonus token | parallel-drafting slots | rejected].
//   - valid tokens are copied from target_token_ids (skipping the first input
//     token when shift_input_ids is set); the source index is capped at
//     total_input_tokens - 1;
//   - the bonus slot gets next_token_ids[i];
//   - parallel-drafting slots get parallel_drafting_token_id;
//   - any remaining slot gets padding_token_id.
// Positions are target_positions[query_start] + slot offset, or 0 for
// rejected slots. The two masks flag rejected and parallel-drafting slots.
// out_new_token_indices records the output index of every bonus/parallel
// slot, and when shift_input_ids is set out_hidden_state_mapping maps each
// input token of the request to its output index.
//
// Token ids and positions are accessed as int64, query locations and index
// outputs as int32, masks as bool. Requests are processed independently
// inside an OpenMP parallel loop.
void copy_and_expand_eagle_inputs_kernel_impl(
    const torch::Tensor& target_token_ids,
    const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
    torch::Tensor& out_input_ids, torch::Tensor& out_positions,
    torch::Tensor& out_is_rejected_token_mask,
    torch::Tensor& out_is_masked_token_mask,
    torch::Tensor& out_new_token_indices,
    torch::Tensor& out_hidden_state_mapping,
    const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
    const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
    const int64_t total_input_tokens,
    const int64_t num_padding_slots_per_request, const bool shift_input_ids) {
  const int64_t num_reqs = query_end_loc.size(0);

  const int64_t* src_ids = target_token_ids.data_ptr<int64_t>();
  const int64_t* src_positions = target_positions.data_ptr<int64_t>();
  const int64_t* bonus_ids = next_token_ids.data_ptr<int64_t>();
  const int32_t* q_starts = query_start_loc.data_ptr<int32_t>();
  const int32_t* q_ends = query_end_loc.data_ptr<int32_t>();

  int64_t* dst_ids = out_input_ids.data_ptr<int64_t>();
  int64_t* dst_positions = out_positions.data_ptr<int64_t>();
  bool* dst_rejected = out_is_rejected_token_mask.data_ptr<bool>();
  bool* dst_masked = out_is_masked_token_mask.data_ptr<bool>();
  int32_t* dst_new_token_idx = out_new_token_indices.data_ptr<int32_t>();
  int32_t* dst_hidden_map = out_hidden_state_mapping.data_ptr<int32_t>();

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    const int32_t q_start = q_starts[req];
    const int32_t q_next_start = q_starts[req + 1];
    const int32_t q_end = q_ends[req];

    // With shifting, the first input token is dropped from the copy.
    const int64_t shift = shift_input_ids ? 1 : 0;
    int64_t num_valid = q_end - q_start;
    if (!shift_input_ids) {
      num_valid = q_end - q_start + 1;
    }
    const int64_t dst_base =
        q_start + req * (num_padding_slots_per_request - shift);
    const int64_t num_rejected = q_next_start - q_end - 1;
    const int64_t padding_end = num_valid + num_padding_slots_per_request;
    const int64_t total_slots = padding_end + num_rejected;

    const int64_t first_position = src_positions[q_start];
    const int64_t bonus = bonus_ids[req];

    for (int64_t j = 0; j < total_slots; ++j) {
      const int64_t dst = dst_base + j;

      // Classify the slot by its offset within the request.
      const bool in_valid = j < num_valid;
      const bool at_bonus = j == num_valid;
      const bool in_parallel = (j > num_valid) && (j < padding_end);
      const bool in_rejected = j >= padding_end;

      int64_t token = padding_token_id;
      if (in_valid) {
        const int64_t src = std::min(static_cast<int64_t>(q_start + shift + j),
                                     total_input_tokens - 1);
        token = src_ids[src];
      } else if (at_bonus) {
        token = bonus;
      } else if (in_parallel) {
        token = parallel_drafting_token_id;
      }

      dst_ids[dst] = token;
      if (in_rejected) {
        dst_positions[dst] = 0;
      } else {
        dst_positions[dst] = first_position + j;
      }
      dst_rejected[dst] = in_rejected;
      dst_masked[dst] = in_parallel;

      // Bonus and parallel-drafting slots are the "new" tokens of a request.
      if (at_bonus || in_parallel) {
        const int64_t local_idx = j - num_valid;
        dst_new_token_idx[req * num_padding_slots_per_request + local_idx] =
            static_cast<int32_t>(dst);
      }
    }

    if (shift_input_ids) {
      const int64_t num_inputs = q_next_start - q_start;
      for (int64_t j = 0; j < num_inputs; ++j) {
        dst_hidden_map[q_start + j] = static_cast<int32_t>(dst_base + j);
      }
    }
  }
}
// Greedy rejection sampling: accepts draft tokens while they match the target
// argmax.
//
// For each request (optionally restricted to rows where is_greedy is true),
// the draft tokens in [cu_num_draft_tokens[i-1], cu_num_draft_tokens[i]) are
// compared with target_argmax. At every visited position the target argmax is
// written to the request's row of output_token_ids; scanning stops after the
// first mismatch. If every draft token matched, the request's bonus token is
// written right after the accepted tokens.
//
// Rows of output_token_ids and bonus_token_ids are addressed via stride(0).
// All id tensors are accessed as int64, is_greedy as bool. max_spec_len is
// part of the op signature but is not read here. Requests are processed
// independently inside an OpenMP parallel loop.
void rejection_greedy_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
    const torch::Tensor& bonus_token_ids,
    const std::optional<torch::Tensor>& is_greedy,
    const int64_t max_spec_len) {
  const int64_t num_reqs = cu_num_draft_tokens.size(0);

  int64_t* out = output_token_ids.data_ptr<int64_t>();
  const int64_t* cumulative_drafts = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* drafts = draft_token_ids.data_ptr<int64_t>();
  const int64_t* argmax_ids = target_argmax.data_ptr<int64_t>();
  const int64_t* bonus = bonus_token_ids.data_ptr<int64_t>();
  const bool* greedy_mask = nullptr;
  if (is_greedy.has_value()) {
    greedy_mask = is_greedy.value().data_ptr<bool>();
  }

  const int64_t out_row_stride = output_token_ids.stride(0);
  const int64_t bonus_row_stride = bonus_token_ids.stride(0);

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    // Non-greedy requests are left untouched by this kernel.
    const bool skip = (greedy_mask != nullptr) && !greedy_mask[req];
    if (skip) {
      continue;
    }

    const int64_t first = (req == 0) ? 0 : cumulative_drafts[req - 1];
    const int64_t num_drafts = cumulative_drafts[req] - first;
    int64_t* out_row = out + req * out_row_stride;

    // Emit target tokens until the first draft/target disagreement.
    bool mismatch = false;
    for (int64_t k = 0; k < num_drafts && !mismatch; ++k) {
      const int64_t target_token = argmax_ids[first + k];
      out_row[k] = target_token;
      mismatch = drafts[first + k] != target_token;
    }

    if (!mismatch) {
      out_row[num_drafts] = bonus[req * bonus_row_stride];
    }
  }
}
// Random (non-greedy) rejection sampling over draft tokens.
//
// For each request (skipping rows where is_greedy is true), the draft tokens
// in [cu_num_draft_tokens[i-1], cu_num_draft_tokens[i]) are visited in order.
// A draft token with target probability p and draft probability q (q is taken
// as 1 when no_draft_probs is set) is accepted iff q > 0 and
// p / q >= uniform_probs[token]. An accepted token is written to the request's
// row of output_token_ids; on the first rejection the matching entry of
// recovered_token_ids is written instead and scanning stops. If no token was
// rejected, the request's bonus token is appended after the accepted tokens.
//
// Probabilities are accessed as float, ids as int64, is_greedy as bool; rows
// are addressed via stride(0). draft_probs must hold a tensor unless
// no_draft_probs is true. max_spec_len and vocab_size are part of the op
// signature but are not read here. Requests are processed independently
// inside an OpenMP parallel loop.
void rejection_random_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
    const torch::Tensor& recovered_token_ids,
    const torch::Tensor& uniform_probs,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
    const int64_t vocab_size, const bool no_draft_probs) {
  const int64_t batch_size = cu_num_draft_tokens.size(0);

  int64_t* out_ptr = output_token_ids.data_ptr<int64_t>();
  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* draft_ids_ptr = draft_token_ids.data_ptr<int64_t>();
  const float* draft_probs_ptr =
      no_draft_probs ? nullptr : draft_probs.value().data_ptr<float>();
  const float* target_probs_ptr = target_probs.data_ptr<float>();
  const int64_t* bonus_ids_ptr = bonus_token_ids.data_ptr<int64_t>();
  const int64_t* recovered_ids_ptr = recovered_token_ids.data_ptr<int64_t>();
  const float* uniform_probs_ptr = uniform_probs.data_ptr<float>();
  const bool* greedy_ptr =
      is_greedy.has_value() ? is_greedy.value().data_ptr<bool>() : nullptr;

  const int64_t out_stride = output_token_ids.stride(0);
  const int64_t bonus_stride = bonus_token_ids.stride(0);
  const int64_t target_stride = target_probs.stride(0);
  const int64_t draft_probs_stride =
      no_draft_probs ? 0 : draft_probs.value().stride(0);

#pragma omp parallel for
  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
    // Greedy requests are handled by the greedy kernel.
    if (greedy_ptr && greedy_ptr[req_idx]) continue;

    const int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
    const int64_t end_idx = cu_draft_ptr[req_idx];
    const int64_t num_draft_tokens = end_idx - start_idx;

    bool rejected = false;
    for (int64_t pos = 0; pos < num_draft_tokens; ++pos) {
      const int64_t token_idx = start_idx + pos;
      const int64_t draft_id = draft_ids_ptr[token_idx];
      const float p = target_probs_ptr[token_idx * target_stride + draft_id];
      const float q =
          no_draft_probs
              ? 1.0f
              : draft_probs_ptr[token_idx * draft_probs_stride + draft_id];
      const float uniform_p = uniform_probs_ptr[token_idx];

      // A non-positive draft probability must always reject. Folding it into
      // a ratio of 0 and comparing with >= would accept the token whenever
      // the uniform sample is exactly 0.
      const bool accepted = (q > 0.0f) && (p / q >= uniform_p);
      if (accepted) {
        out_ptr[req_idx * out_stride + pos] = draft_id;
      } else {
        out_ptr[req_idx * out_stride + pos] = recovered_ids_ptr[token_idx];
        rejected = true;
        break;
      }
    }

    if (!rejected) {
      out_ptr[req_idx * out_stride + num_draft_tokens] =
          bonus_ids_ptr[req_idx * bonus_stride];
    }
  }
}
// Expands one value per request into one value per token.
//
// Request i owns the output range [cu_num_tokens[i-1], cu_num_tokens[i])
// (request 0 starts at 0). Every element of that range is set to input[i],
// except that an input equal to replace_from is written as replace_to.
//
// output, input and cu_num_tokens are all accessed as int64. Requests are
// processed independently inside an OpenMP parallel loop.
void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
                        const torch::Tensor& cu_num_tokens,
                        const int64_t replace_from, const int64_t replace_to) {
  const int64_t num_reqs = cu_num_tokens.size(0);
  const int64_t* cumulative = cu_num_tokens.data_ptr<int64_t>();
  int64_t* dst = output.data_ptr<int64_t>();
  const int64_t* src = input.data_ptr<int64_t>();

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    const int64_t range_begin = (req == 0) ? 0 : cumulative[req - 1];
    const int64_t range_end = cumulative[req];

    // Substitute the sentinel value before broadcasting.
    const int64_t raw = src[req];
    const int64_t fill_value = (raw == replace_from) ? replace_to : raw;

    int64_t slot = range_begin;
    while (slot < range_end) {
      dst[slot] = fill_value;
      ++slot;
    }
  }
}
// Picks a "recovered" token for every draft position.
//
// For each draft token position, a residual score is built per vocabulary
// entry v:
//   - no_draft_probs: the target probability, with the draft token's own
//     entry forced to 0;
//   - otherwise: target_prob[v] - draft_prob[v], with non-positive
//     differences replaced by 0.
// The score is multiplied by the request's inv_q[v] and the index of the
// largest product (first one wins on ties, starting from a best value of -1)
// is written to output_token_ids at that token position.
//
// Ids are accessed as int64 and probabilities / inv_q as float; rows are
// addressed via stride(0), with inv_q indexed per request and the probability
// tensors per token. draft_probs must hold a tensor unless no_draft_probs is
// true. Requests are processed independently inside an OpenMP parallel loop.
void sample_recovered_tokens_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
    const int64_t vocab_size, const bool no_draft_probs) {
  const int64_t num_reqs = cu_num_draft_tokens.size(0);

  int64_t* recovered_out = output_token_ids.data_ptr<int64_t>();
  const int64_t* cumulative_drafts = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* drafts = draft_token_ids.data_ptr<int64_t>();
  const float* draft_p = nullptr;
  if (!no_draft_probs) {
    draft_p = draft_probs.value().data_ptr<float>();
  }
  const float* target_p = target_probs.data_ptr<float>();
  const float* inv_q_base = inv_q.data_ptr<float>();

  const int64_t target_row_stride = target_probs.stride(0);
  int64_t draft_row_stride = 0;
  if (!no_draft_probs) {
    draft_row_stride = draft_probs.value().stride(0);
  }
  const int64_t inv_q_row_stride = inv_q.stride(0);

#pragma omp parallel for
  for (int64_t req = 0; req < num_reqs; ++req) {
    const int64_t first = (req == 0) ? 0 : cumulative_drafts[req - 1];
    const int64_t num_drafts = cumulative_drafts[req] - first;
    const float* inv_q_row = inv_q_base + req * inv_q_row_stride;

    for (int64_t k = 0; k < num_drafts; ++k) {
      const int64_t tok = first + k;
      const int64_t draft_token = drafts[tok];
      const float* target_row = target_p + tok * target_row_stride;
      const float* draft_row =
          no_draft_probs ? nullptr : (draft_p + tok * draft_row_stride);

      // Running argmax over the scaled residual distribution.
      int64_t argmax_id = 0;
      float argmax_val = -1.0f;
      for (int64_t v = 0; v < vocab_size; ++v) {
        float residual = target_row[v];
        if (no_draft_probs) {
          // The draft token itself can never be the recovered token.
          if (v == draft_token) {
            residual = 0.0f;
          }
        } else {
          residual -= draft_row[v];
          if (!(residual > 0.0f)) {
            residual = 0.0f;
          }
        }

        const float score = residual * inv_q_row[v];
        if (score > argmax_val) {
          argmax_val = score;
          argmax_id = v;
        }
      }
      recovered_out[tok] = argmax_id;
    }
  }
}
}
// namespace cpu_utils
csrc/cpu/torch_bindings.cpp
View file @
445a2a4d
...
@@ -138,6 +138,61 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
...
@@ -138,6 +138,61 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
torch
::
Tensor
slot_mapping
,
torch
::
Tensor
slot_mapping
,
const
int64_t
block_size
);
const
int64_t
block_size
);
namespace
cpu_utils
{
void
eagle_prepare_inputs_padded_kernel_impl
(
const
torch
::
Tensor
&
cu_num_draft_tokens
,
const
torch
::
Tensor
&
valid_sampled_tokens_count
,
const
torch
::
Tensor
&
query_start_loc_gpu
,
torch
::
Tensor
&
token_indices_to_sample
,
torch
::
Tensor
&
num_rejected_tokens_gpu
,
const
int64_t
num_reqs
);
void
eagle_prepare_next_token_padded_kernel_impl
(
const
torch
::
Tensor
&
sampled_token_ids
,
const
torch
::
Tensor
&
discard_request_mask
,
const
torch
::
Tensor
&
backup_next_token_ids
,
torch
::
Tensor
&
next_token_ids
,
torch
::
Tensor
&
valid_sampled_tokens_count
,
const
int64_t
vocab_size
,
const
int64_t
num_sampled_tokens_per_req
,
const
int64_t
num_reqs
);
void
eagle_step_slot_mapping_metadata_kernel_impl
(
const
torch
::
Tensor
&
positions
,
const
torch
::
Tensor
&
block_table
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
out_clamped_positions
,
torch
::
Tensor
&
out_slot_mapping
,
const
int64_t
block_size
,
const
int64_t
max_model_len
,
const
int64_t
PAD_ID
);
void
copy_and_expand_eagle_inputs_kernel_impl
(
const
torch
::
Tensor
&
target_token_ids
,
const
torch
::
Tensor
&
target_positions
,
const
torch
::
Tensor
&
next_token_ids
,
torch
::
Tensor
&
out_input_ids
,
torch
::
Tensor
&
out_positions
,
torch
::
Tensor
&
out_is_rejected_token_mask
,
torch
::
Tensor
&
out_is_masked_token_mask
,
torch
::
Tensor
&
out_new_token_indices
,
torch
::
Tensor
&
out_hidden_state_mapping
,
const
torch
::
Tensor
&
query_start_loc
,
const
torch
::
Tensor
&
query_end_loc
,
const
int64_t
padding_token_id
,
const
int64_t
parallel_drafting_token_id
,
const
int64_t
total_input_tokens
,
const
int64_t
num_padding_slots_per_request
,
const
bool
shift_input_ids
);
void
rejection_greedy_sample_kernel_impl
(
torch
::
Tensor
&
output_token_ids
,
const
torch
::
Tensor
&
cu_num_draft_tokens
,
const
torch
::
Tensor
&
draft_token_ids
,
const
torch
::
Tensor
&
target_argmax
,
const
torch
::
Tensor
&
bonus_token_ids
,
const
std
::
optional
<
torch
::
Tensor
>&
is_greedy
,
const
int64_t
max_spec_len
);
void
rejection_random_sample_kernel_impl
(
torch
::
Tensor
&
output_token_ids
,
const
torch
::
Tensor
&
cu_num_draft_tokens
,
const
torch
::
Tensor
&
draft_token_ids
,
const
std
::
optional
<
torch
::
Tensor
>&
draft_probs
,
const
torch
::
Tensor
&
target_probs
,
const
torch
::
Tensor
&
bonus_token_ids
,
const
torch
::
Tensor
&
recovered_token_ids
,
const
torch
::
Tensor
&
uniform_probs
,
const
std
::
optional
<
torch
::
Tensor
>&
is_greedy
,
const
int64_t
max_spec_len
,
const
int64_t
vocab_size
,
const
bool
no_draft_probs
);
void
expand_kernel_impl
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
cu_num_tokens
,
const
int64_t
replace_from
,
const
int64_t
replace_to
);
void
sample_recovered_tokens_kernel_impl
(
torch
::
Tensor
&
output_token_ids
,
const
torch
::
Tensor
&
cu_num_draft_tokens
,
const
torch
::
Tensor
&
draft_token_ids
,
const
std
::
optional
<
torch
::
Tensor
>&
draft_probs
,
const
torch
::
Tensor
&
target_probs
,
const
torch
::
Tensor
&
inv_q
,
const
int64_t
vocab_size
,
const
bool
no_draft_probs
);
}
// namespace cpu_utils
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// vLLM custom ops
...
@@ -363,6 +418,70 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -363,6 +418,70 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
"block_size) -> ()"
,
"block_size) -> ()"
,
&
compute_slot_mapping_kernel_impl
);
&
compute_slot_mapping_kernel_impl
);
// Speculative decoding kernels
ops
.
def
(
"eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
"Tensor valid_sampled_tokens_count, Tensor query_start_loc_gpu, "
"Tensor(a3!) token_indices_to_sample, "
"Tensor(a4!) num_rejected_tokens_gpu, "
"SymInt num_reqs) -> ()"
,
&
cpu_utils
::
eagle_prepare_inputs_padded_kernel_impl
);
ops
.
def
(
"eagle_prepare_next_token_padded_kernel_impl("
"Tensor sampled_token_ids, Tensor discard_request_mask, "
"Tensor backup_next_token_ids, Tensor(a3!) next_token_ids, "
"Tensor(a4!) valid_sampled_tokens_count, SymInt vocab_size, "
"SymInt num_sampled_tokens_per_req, SymInt num_reqs) -> ()"
,
&
cpu_utils
::
eagle_prepare_next_token_padded_kernel_impl
);
ops
.
def
(
"eagle_step_slot_mapping_metadata_kernel_impl("
"Tensor positions, Tensor block_table, Tensor(a2!) seq_lens, "
"Tensor(a3!) out_clamped_positions, Tensor(a4!) out_slot_mapping, "
"SymInt block_size, SymInt max_model_len, SymInt PAD_ID) -> ()"
,
&
cpu_utils
::
eagle_step_slot_mapping_metadata_kernel_impl
);
ops
.
def
(
"copy_and_expand_eagle_inputs_kernel_impl("
"Tensor target_token_ids, Tensor target_positions, "
"Tensor next_token_ids, Tensor(a3!) out_input_ids, "
"Tensor(a4!) out_positions, "
"Tensor(a5!) out_is_rejected_token_mask, "
"Tensor(a6!) out_is_masked_token_mask, "
"Tensor(a7!) out_new_token_indices, "
"Tensor(a8!) out_hidden_state_mapping, "
"Tensor query_start_loc, Tensor query_end_loc, "
"SymInt padding_token_id, SymInt parallel_drafting_token_id, "
"SymInt total_input_tokens, SymInt num_padding_slots_per_request, "
"bool shift_input_ids) -> ()"
,
&
cpu_utils
::
copy_and_expand_eagle_inputs_kernel_impl
);
ops
.
def
(
"rejection_greedy_sample_kernel_impl("
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
"Tensor draft_token_ids, Tensor target_argmax, "
"Tensor bonus_token_ids, Tensor? is_greedy, "
"SymInt max_spec_len) -> ()"
,
&
cpu_utils
::
rejection_greedy_sample_kernel_impl
);
ops
.
def
(
"rejection_random_sample_kernel_impl("
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
"Tensor draft_token_ids, Tensor? draft_probs, "
"Tensor target_probs, Tensor bonus_token_ids, "
"Tensor recovered_token_ids, Tensor uniform_probs, "
"Tensor? is_greedy, SymInt max_spec_len, SymInt vocab_size, "
"bool no_draft_probs) -> ()"
,
&
cpu_utils
::
rejection_random_sample_kernel_impl
);
ops
.
def
(
"expand_kernel_impl(Tensor(a0!) output, Tensor input, "
"Tensor cu_num_tokens, SymInt replace_from, "
"SymInt replace_to) -> ()"
,
&
cpu_utils
::
expand_kernel_impl
);
ops
.
def
(
"sample_recovered_tokens_kernel_impl("
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
"Tensor draft_token_ids, Tensor? draft_probs, "
"Tensor target_probs, Tensor inv_q, SymInt vocab_size, "
"bool no_draft_probs) -> ()"
,
&
cpu_utils
::
sample_recovered_tokens_kernel_impl
);
}
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
vllm/utils/cpu_triton_utils.py
View file @
445a2a4d
...
@@ -45,3 +45,277 @@ def _compute_slot_mapping_kernel_impl(
...
@@ -45,3 +45,277 @@ def _compute_slot_mapping_kernel_impl(
compute_slot_mapping_kernel
=
_FuncWrapper
(
_compute_slot_mapping_kernel_impl
)
compute_slot_mapping_kernel
=
_FuncWrapper
(
_compute_slot_mapping_kernel_impl
)
def
_ensure_int64
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
t
if
t
.
dtype
==
torch
.
int64
else
t
.
to
(
torch
.
int64
)
def _eagle_prepare_inputs_padded_kernel_impl(
    cu_num_draft_tokens,
    valid_sampled_tokens_count,
    query_start_loc_gpu,
    token_indices_to_sample,
    num_rejected_tokens_gpu,
    num_reqs,
):
    """Adapter for the C++ ``eagle_prepare_inputs_padded_kernel_impl`` op.

    The C++ op accesses ``cu_num_draft_tokens``,
    ``valid_sampled_tokens_count`` and ``num_rejected_tokens_gpu`` as int64.
    The two inputs are converted on the fly; the rejected-count output is
    written into an int64 buffer and copied back into the caller's tensor when
    that tensor uses a different dtype. ``query_start_loc_gpu`` and
    ``token_indices_to_sample`` are forwarded unchanged.
    """
    needs_copy_back = num_rejected_tokens_gpu.dtype != torch.int64
    if needs_copy_back:
        rejected_buf = num_rejected_tokens_gpu.to(torch.int64)
    else:
        rejected_buf = num_rejected_tokens_gpu

    torch.ops._C.eagle_prepare_inputs_padded_kernel_impl(
        _ensure_int64(cu_num_draft_tokens),
        _ensure_int64(valid_sampled_tokens_count),
        query_start_loc_gpu,
        token_indices_to_sample,
        rejected_buf,
        num_reqs,
    )

    # The op wrote into the temporary; propagate results to the caller.
    if needs_copy_back:
        num_rejected_tokens_gpu.copy_(rejected_buf.to(num_rejected_tokens_gpu.dtype))
def _eagle_prepare_next_token_padded_kernel_impl(
    sampled_token_ids,
    discard_request_mask,
    backup_next_token_ids,
    next_token_ids,
    valid_sampled_tokens_count,
    vocab_size,
    num_sampled_tokens_per_req,
    num_reqs,
    stride=None,
    BLOCK_SIZE_TOKENS=None,
):
    """Adapter for the C++ ``eagle_prepare_next_token_padded_kernel_impl`` op.

    The C++ op accesses every integer tensor as int64. Inputs are converted on
    the fly; the two outputs (``next_token_ids`` and
    ``valid_sampled_tokens_count``) are written into int64 buffers and copied
    back into the caller's tensors when those use another dtype.

    ``stride`` and ``BLOCK_SIZE_TOKENS`` are accepted but not forwarded to the
    C++ op.
    """
    next_dtype = next_token_ids.dtype
    valid_dtype = valid_sampled_tokens_count.dtype
    next_buf = _ensure_int64(next_token_ids)
    valid_buf = _ensure_int64(valid_sampled_tokens_count)

    torch.ops._C.eagle_prepare_next_token_padded_kernel_impl(
        _ensure_int64(sampled_token_ids),
        discard_request_mask,
        _ensure_int64(backup_next_token_ids),
        next_buf,
        valid_buf,
        vocab_size,
        num_sampled_tokens_per_req,
        num_reqs,
    )

    # Propagate results from any temporary int64 buffer to the caller.
    if next_dtype != torch.int64:
        next_token_ids.copy_(next_buf.to(next_dtype))
    if valid_dtype != torch.int64:
        valid_sampled_tokens_count.copy_(valid_buf.to(valid_dtype))
def _eagle_step_slot_mapping_metadata_kernel_impl(
    positions,
    block_table,
    stride,
    seq_lens,
    out_clamped_positions,
    out_slot_mapping,
    block_size,
    max_model_len,
    n_blocks_per_req,
    PAD_ID,
    batch_size=None,
):
    """Adapter for the C++ ``eagle_step_slot_mapping_metadata_kernel_impl`` op.

    ``stride`` and ``n_blocks_per_req`` are accepted but not forwarded: the
    C++ op reads the row stride and column count from ``block_table`` itself.
    ``batch_size``, when given, must equal ``positions.shape[0]``, which is
    the batch size the C++ op uses.
    """
    if batch_size is not None:
        assert batch_size == positions.shape[0], (
            f"batch_size mismatch: {batch_size} vs positions.shape[0]={positions.shape[0]}"
        )

    torch.ops._C.eagle_step_slot_mapping_metadata_kernel_impl(
        positions,
        block_table,
        seq_lens,
        out_clamped_positions,
        out_slot_mapping,
        block_size,
        max_model_len,
        PAD_ID,
    )
def _copy_and_expand_eagle_inputs_kernel_impl(
    target_token_ids_ptr,
    target_positions_ptr,
    next_token_ids_ptr,
    out_input_ids_ptr,
    out_positions_ptr,
    out_is_rejected_token_mask_ptr,
    out_is_masked_token_mask_ptr,
    out_new_token_indices_ptr,
    out_hidden_state_mapping_ptr,
    query_start_loc_ptr,
    query_end_loc_ptr,
    padding_token_id,
    parallel_drafting_token_id,
    total_input_tokens,
    num_padding_slots_per_request,
    shift_input_ids,
    BLOCK_SIZE_TOKENS=None,
    BLOCK_SIZE_REQS=None,
):
    """Bridge the Triton-style call signature to the C++ implementation.

    Parameter names keep the Triton kernel's ``_ptr`` suffix. The
    ``BLOCK_SIZE_TOKENS`` / ``BLOCK_SIZE_REQS`` constants are accepted but not
    forwarded, since the C++ op has no use for them.

    The C++ op accesses token ids and positions as int64. Inputs are converted
    on the fly; ``out_input_ids_ptr`` and ``out_positions_ptr`` are written
    through int64 buffers and copied back when the caller's tensors use a
    different dtype. The masks, index outputs and query locations are
    forwarded unchanged.
    """
    ids_dtype = out_input_ids_ptr.dtype
    pos_dtype = out_positions_ptr.dtype
    ids_buf = _ensure_int64(out_input_ids_ptr)
    pos_buf = _ensure_int64(out_positions_ptr)

    torch.ops._C.copy_and_expand_eagle_inputs_kernel_impl(
        _ensure_int64(target_token_ids_ptr),
        _ensure_int64(target_positions_ptr),
        _ensure_int64(next_token_ids_ptr),
        ids_buf,
        pos_buf,
        out_is_rejected_token_mask_ptr,
        out_is_masked_token_mask_ptr,
        out_new_token_indices_ptr,
        out_hidden_state_mapping_ptr,
        query_start_loc_ptr,
        query_end_loc_ptr,
        padding_token_id,
        parallel_drafting_token_id,
        total_input_tokens,
        num_padding_slots_per_request,
        shift_input_ids,
    )

    # Propagate results from any temporary int64 buffer to the caller.
    if ids_dtype != torch.int64:
        out_input_ids_ptr.copy_(ids_buf.to(ids_dtype))
    if pos_dtype != torch.int64:
        out_positions_ptr.copy_(pos_buf.to(pos_dtype))
def _rejection_greedy_sample_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    target_argmax,
    bonus_token_ids,
    is_greedy,
    max_spec_len,
):
    """Adapter for the C++ ``rejection_greedy_sample_kernel_impl`` op.

    The C++ op accesses every integer tensor as int64. Inputs are converted on
    the fly; ``output_token_ids`` is written through an int64 buffer and
    copied back when the caller's tensor uses a different dtype.
    ``is_greedy`` is forwarded unchanged.
    """
    out_dtype = output_token_ids.dtype
    out_buf = _ensure_int64(output_token_ids)

    torch.ops._C.rejection_greedy_sample_kernel_impl(
        out_buf,
        _ensure_int64(cu_num_draft_tokens),
        _ensure_int64(draft_token_ids),
        _ensure_int64(target_argmax),
        _ensure_int64(bonus_token_ids),
        is_greedy,
        max_spec_len,
    )

    # Propagate results from the temporary int64 buffer to the caller.
    if out_dtype != torch.int64:
        output_token_ids.copy_(out_buf.to(out_dtype))
def _rejection_random_sample_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    draft_probs,
    target_probs,
    bonus_token_ids,
    recovered_token_ids,
    uniform_probs,
    is_greedy,
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS=False,
):
    """Adapter for the C++ ``rejection_random_sample_kernel_impl`` op.

    The C++ op accesses integer tensors as int64 and probabilities as float32.
    Integer inputs are converted on the fly, and ``uniform_probs`` is cast to
    float32 for the call. ``draft_probs``, ``target_probs`` and ``is_greedy``
    are forwarded unchanged. ``output_token_ids`` is written through an int64
    buffer and copied back when the caller's tensor uses a different dtype.
    """
    out_dtype = output_token_ids.dtype
    out_buf = _ensure_int64(output_token_ids)

    torch.ops._C.rejection_random_sample_kernel_impl(
        out_buf,
        _ensure_int64(cu_num_draft_tokens),
        _ensure_int64(draft_token_ids),
        draft_probs,
        target_probs,
        _ensure_int64(bonus_token_ids),
        _ensure_int64(recovered_token_ids),
        uniform_probs.to(torch.float32),
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS,
    )

    # Propagate results from the temporary int64 buffer to the caller.
    if out_dtype != torch.int64:
        output_token_ids.copy_(out_buf.to(out_dtype))
def
_expand_kernel_impl
(
output
,
input_val
,
cu_num_tokens
,
replace_from
,
replace_to
,
MAX_NUM_TOKENS
=
None
,
):
torch
.
ops
.
_C
.
expand_kernel_impl
(
_ensure_int64
(
output
),
_ensure_int64
(
input_val
),
_ensure_int64
(
cu_num_tokens
),
replace_from
,
replace_to
,
)
def _sample_recovered_tokens_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    draft_probs,
    target_probs,
    inv_q,
    vocab_size,
    BLOCK_SIZE=None,
    NO_DRAFT_PROBS=False,
):
    """Adapter for the C++ ``sample_recovered_tokens_kernel_impl`` op.

    The C++ op accesses integer tensors as int64. Integer inputs are converted
    on the fly; ``output_token_ids`` is written through an int64 buffer and
    copied back when the caller's tensor uses a different dtype.
    ``draft_probs``, ``target_probs`` and ``inv_q`` are forwarded unchanged.

    ``BLOCK_SIZE`` is accepted but not forwarded to the C++ op.
    """
    out_dtype = output_token_ids.dtype
    out_buf = _ensure_int64(output_token_ids)

    torch.ops._C.sample_recovered_tokens_kernel_impl(
        out_buf,
        _ensure_int64(cu_num_draft_tokens),
        _ensure_int64(draft_token_ids),
        draft_probs,
        target_probs,
        inv_q,
        vocab_size,
        NO_DRAFT_PROBS,
    )

    # Propagate results from the temporary int64 buffer to the caller.
    if out_dtype != torch.int64:
        output_token_ids.copy_(out_buf.to(out_dtype))
# Public kernel entry points for the speculative-decoding ops. Each name binds
# the matching C++-backed adapter defined above, wrapped in _FuncWrapper.
# NOTE(review): _FuncWrapper is defined outside this hunk; presumably it lets
# callers keep the Triton-style launch syntax -- confirm against its
# definition.
eagle_prepare_inputs_padded_kernel = _FuncWrapper(
    _eagle_prepare_inputs_padded_kernel_impl
)
eagle_prepare_next_token_padded_kernel = _FuncWrapper(
    _eagle_prepare_next_token_padded_kernel_impl
)
copy_and_expand_eagle_inputs_kernel = _FuncWrapper(
    _copy_and_expand_eagle_inputs_kernel_impl
)
eagle_step_slot_mapping_metadata_kernel = _FuncWrapper(
    _eagle_step_slot_mapping_metadata_kernel_impl
)
# Rejection-sampling kernels.
rejection_greedy_sample_kernel = _FuncWrapper(_rejection_greedy_sample_kernel_impl)
rejection_random_sample_kernel = _FuncWrapper(_rejection_random_sample_kernel_impl)
expand_kernel = _FuncWrapper(_expand_kernel_impl)
sample_recovered_tokens_kernel = _FuncWrapper(_sample_recovered_tokens_kernel_impl)
vllm/v1/spec_decode/eagle.py
View file @
445a2a4d
...
@@ -26,7 +26,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
...
@@ -26,7 +26,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from
vllm.model_executor.models.qwen3_dflash
import
DFlashQwen3ForCausalLM
from
vllm.model_executor.models.qwen3_dflash
import
DFlashQwen3ForCausalLM
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
...
@@ -48,6 +47,7 @@ from vllm.v1.spec_decode.utils import (
...
@@ -48,6 +47,7 @@ from vllm.v1.spec_decode.utils import (
eagle_prepare_next_token_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
eagle_step_update_slot_mapping_and_metadata
,
eagle_step_update_slot_mapping_and_metadata
,
extend_all_queries_by_N
,
extend_all_queries_by_N
,
next_power_of_2
,
)
)
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
...
@@ -689,9 +689,7 @@ class SpecDecodeBaseProposer:
...
@@ -689,9 +689,7 @@ class SpecDecodeBaseProposer:
max_num_tokens_per_request
=
(
max_num_tokens_per_request
=
(
cad
.
max_query_len
+
self
.
net_num_new_slots_per_request
cad
.
max_query_len
+
self
.
net_num_new_slots_per_request
)
)
BLOCK_SIZE_TOKENS
=
min
(
BLOCK_SIZE_TOKENS
=
min
(
256
,
next_power_of_2
(
max_num_tokens_per_request
))
256
,
triton
.
next_power_of_2
(
max_num_tokens_per_request
)
)
num_blocks
=
(
num_blocks
=
(
max_num_tokens_per_request
+
BLOCK_SIZE_TOKENS
-
1
max_num_tokens_per_request
+
BLOCK_SIZE_TOKENS
-
1
)
//
BLOCK_SIZE_TOKENS
)
//
BLOCK_SIZE_TOKENS
...
@@ -717,6 +715,7 @@ class SpecDecodeBaseProposer:
...
@@ -717,6 +715,7 @@ class SpecDecodeBaseProposer:
query_end_loc
=
cad
.
query_start_loc
[
1
:]
-
1
query_end_loc
=
cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
if
num_rejected_tokens_gpu
is
not
None
:
query_end_loc
=
query_end_loc
-
num_rejected_tokens_gpu
query_end_loc
=
query_end_loc
-
num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel
[
grid
](
copy_and_expand_eagle_inputs_kernel
[
grid
](
# (Padded) Inputs from the target model
# (Padded) Inputs from the target model
target_token_ids_ptr
=
target_token_ids
,
target_token_ids_ptr
=
target_token_ids
,
...
@@ -899,7 +898,7 @@ class SpecDecodeBaseProposer:
...
@@ -899,7 +898,7 @@ class SpecDecodeBaseProposer:
grid
=
(
batch_size
,)
grid
=
(
batch_size
,)
# Find the next power of 2 for block sizes
# Find the next power of 2 for block sizes
BLOCK_SIZE_TOKENS
=
triton
.
next_power_of_2
(
num_tokens
)
BLOCK_SIZE_TOKENS
=
next_power_of_2
(
num_tokens
)
eagle_prepare_next_token_padded_kernel
[
grid
](
eagle_prepare_next_token_padded_kernel
[
grid
](
sampled_token_ids
,
sampled_token_ids
,
discard_request_mask
,
discard_request_mask
,
...
...
vllm/v1/spec_decode/utils.py
View file @
445a2a4d
...
@@ -11,6 +11,20 @@ from vllm.v1.attention.backends.utils import (
...
@@ -11,6 +11,20 @@ from vllm.v1.attention.backends.utils import (
PADDING_SLOT_ID
=
-
1
PADDING_SLOT_ID
=
-
1
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 >= n.

    Args:
        n: Any integer. Values <= 0 yield 1.

    Returns:
        The smallest power of two that is greater than or equal to ``n``.
    """
    if n <= 0:
        return 1
    # (n - 1).bit_length() is the number of bits needed to represent n - 1,
    # so shifting 1 by that amount gives the first power of two >= n. Unlike
    # a fixed-width bit smear, this stays correct for arbitrarily large ints.
    return 1 << (n - 1).bit_length()
@
triton
.
jit
@
triton
.
jit
def
eagle_step_slot_mapping_metadata_kernel
(
def
eagle_step_slot_mapping_metadata_kernel
(
positions_ptr
,
# [batch_size] - current positions (1D view for M-RoPE)
positions_ptr
,
# [batch_size] - current positions (1D view for M-RoPE)
...
@@ -102,8 +116,8 @@ def eagle_step_update_slot_mapping_and_metadata(
...
@@ -102,8 +116,8 @@ def eagle_step_update_slot_mapping_and_metadata(
batch_size
=
positions_1d
.
shape
[
0
]
batch_size
=
positions_1d
.
shape
[
0
]
if
input_batch_size
is
None
:
if
input_batch_size
is
None
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
n_blocks_per_req
=
block_table_tensor
.
shape
[
1
]
n_blocks_per_req
=
block_table_tensor
.
shape
[
1
]
eagle_step_slot_mapping_metadata_kernel
[(
input_batch_size
,)](
eagle_step_slot_mapping_metadata_kernel
[(
input_batch_size
,)](
positions_1d
,
positions_1d
,
block_table_tensor
,
block_table_tensor
,
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
445a2a4d
...
@@ -11,6 +11,8 @@ from vllm.config import VllmConfig
...
@@ -11,6 +11,8 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.tracing
import
instrument
from
vllm.tracing
import
instrument
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -23,7 +25,7 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -23,7 +25,7 @@ class CPUModelRunner(GPUModelRunner):
super
().
__init__
(
vllm_config
,
device
)
super
().
__init__
(
vllm_config
,
device
)
assert
device
==
torch
.
device
(
"cpu"
)
assert
device
==
torch
.
device
(
"cpu"
)
assert
self
.
speculative_config
is
None
,
"spec decode is not supported."
# Note: speculative decoding is now supported on CPU with C++ native impls
self
.
use_cuda_graph
=
False
self
.
use_cuda_graph
=
False
self
.
cascade_attn_enabled
=
False
self
.
cascade_attn_enabled
=
False
...
@@ -61,6 +63,34 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -61,6 +63,34 @@ class CPUModelRunner(GPUModelRunner):
cpu_tl
.
compute_slot_mapping_kernel
cpu_tl
.
compute_slot_mapping_kernel
)
)
# Speculative decoding fallbacks
import
vllm.v1.sample.rejection_sampler
import
vllm.v1.spec_decode.eagle
import
vllm.v1.spec_decode.utils
vllm
.
v1
.
spec_decode
.
eagle
.
eagle_prepare_inputs_padded_kernel
=
(
cpu_tl
.
eagle_prepare_inputs_padded_kernel
)
vllm
.
v1
.
spec_decode
.
eagle
.
eagle_prepare_next_token_padded_kernel
=
(
cpu_tl
.
eagle_prepare_next_token_padded_kernel
)
vllm
.
v1
.
spec_decode
.
eagle
.
copy_and_expand_eagle_inputs_kernel
=
(
cpu_tl
.
copy_and_expand_eagle_inputs_kernel
)
vllm
.
v1
.
spec_decode
.
utils
.
eagle_step_slot_mapping_metadata_kernel
=
(
cpu_tl
.
eagle_step_slot_mapping_metadata_kernel
)
vllm
.
v1
.
sample
.
rejection_sampler
.
rejection_greedy_sample_kernel
=
(
cpu_tl
.
rejection_greedy_sample_kernel
)
vllm
.
v1
.
sample
.
rejection_sampler
.
rejection_random_sample_kernel
=
(
cpu_tl
.
rejection_random_sample_kernel
)
vllm
.
v1
.
sample
.
rejection_sampler
.
expand_kernel
=
cpu_tl
.
expand_kernel
vllm
.
v1
.
sample
.
rejection_sampler
.
sample_recovered_tokens_kernel
=
(
cpu_tl
.
sample_recovered_tokens_kernel
)
@
instrument
(
span_name
=
"Loading (CPU)"
)
@
instrument
(
span_name
=
"Loading (CPU)"
)
def
load_model
(
self
,
load_dummy_weights
:
bool
=
False
)
->
None
:
def
load_model
(
self
,
load_dummy_weights
:
bool
=
False
)
->
None
:
if
load_dummy_weights
:
if
load_dummy_weights
:
...
@@ -74,6 +104,10 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -74,6 +104,10 @@ class CPUModelRunner(GPUModelRunner):
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
model
=
self
.
load_lora_model
(
self
.
model
,
self
.
vllm_config
,
self
.
device
)
self
.
model
=
self
.
load_lora_model
(
self
.
model
,
self
.
vllm_config
,
self
.
device
)
if
hasattr
(
self
,
"drafter"
):
logger
.
info_once
(
"Loading drafter model..."
)
self
.
drafter
.
load_model
(
self
.
model
)
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
return
self
.
model
...
@@ -89,8 +123,29 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -89,8 +123,29 @@ class CPUModelRunner(GPUModelRunner):
)
)
)
)
# Warm up drafter for speculative decoding
if
self
.
speculative_config
and
(
self
.
speculative_config
.
uses_draft_model
()):
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
if
isinstance
(
self
.
drafter
,
(
DraftModelProposer
)):
logger
.
info
(
"Warming up drafter model..."
)
self
.
drafter
.
dummy_run
(
max
(
16
,
self
.
max_num_reqs
))
logger
.
info
(
"Warming up done."
)
logger
.
info
(
"Warming up done."
)
def initialize_kv_cache(
    self,
    kv_cache_config: KVCacheConfig,
    is_profiling: bool = False,
) -> None:
    """Initialize the KV cache via the parent class, then log the drafter kind.

    When a speculative config is set, one info line is logged: the EAGLE
    message if ``use_eagle()`` is true, otherwise the draft-model message
    if ``uses_draft_model()`` is true.
    """
    super().initialize_kv_cache(kv_cache_config, is_profiling)

    # Nothing more to report without speculative decoding.
    if not self.speculative_config:
        return

    if self.speculative_config.use_eagle():
        logger.info("EAGLE drafter KV cache initialized for CPU backend")
        return
    if self.speculative_config.uses_draft_model():
        logger.info("Draft model KV cache initialized for CPU backend")
def
_init_device_properties
(
self
)
->
None
:
def
_init_device_properties
(
self
)
->
None
:
pass
pass
...
@@ -102,6 +157,71 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -102,6 +157,71 @@ class CPUModelRunner(GPUModelRunner):
# so stale KV cache data never affects computation.
# so stale KV cache data never affects computation.
pass
pass
# =========================================================================
# CPU-safe overrides for speculative decoding methods
# These methods override GPU-specific implementations that use CUDA streams
# =========================================================================
def
_copy_draft_token_ids_to_cpu
(
self
,
scheduler_output
:
"SchedulerOutput"
,
zeros_only
:
bool
=
False
)
->
None
:
"""CPU-safe version: no async copy needed, tensors already on CPU."""
if
self
.
use_async_scheduling
and
not
(
scheduler_output
.
has_structured_output_requests
or
self
.
input_batch
.
sampling_metadata
.
output_token_ids
):
return
self
.
_draft_token_req_ids
=
self
.
input_batch
.
req_ids
.
copy
()
draft_token_ids
:
torch
.
Tensor
=
self
.
_draft_token_ids
if
not
torch
.
is_tensor
(
draft_token_ids
):
return
num_reqs
=
draft_token_ids
.
shape
[
0
]
if
self
.
draft_token_ids_cpu
is
not
None
:
if
not
zeros_only
:
self
.
draft_token_ids_cpu
[:
num_reqs
].
copy_
(
draft_token_ids
)
else
:
self
.
draft_token_ids_cpu
[:
num_reqs
]
=
0
def
_get_draft_token_ids_cpu
(
self
)
->
tuple
[
list
[
list
[
int
]],
list
[
str
]]:
"""CPU-safe version: no event synchronization needed."""
if
isinstance
(
self
.
_draft_token_ids
,
list
):
return
self
.
_draft_token_ids
,
self
.
input_batch
.
req_ids
req_ids
=
self
.
_draft_token_req_ids
if
req_ids
is
None
:
return
[],
[]
if
self
.
draft_token_ids_cpu
is
not
None
:
return
self
.
draft_token_ids_cpu
[:
len
(
req_ids
)].
tolist
(),
req_ids
return
[],
[]
def
_copy_valid_sampled_token_count
(
self
,
next_token_ids
:
torch
.
Tensor
,
valid_sampled_tokens_count
:
torch
.
Tensor
)
->
None
:
"""CPU-safe version: direct copy without CUDA streams."""
if
self
.
valid_sampled_token_count_cpu
is
None
:
return
counts
=
valid_sampled_tokens_count
counts_cpu
=
self
.
valid_sampled_token_count_cpu
counts_cpu
[:
counts
.
shape
[
0
]].
copy_
(
counts
)
self
.
input_batch
.
prev_sampled_token_ids
=
next_token_ids
.
unsqueeze
(
1
)
def
_get_valid_sampled_token_count
(
self
)
->
list
[
int
]:
"""CPU-safe version: no event synchronization needed."""
prev_sampled_token_ids
=
self
.
input_batch
.
prev_sampled_token_ids
if
prev_sampled_token_ids
is
None
:
return
[]
counts_cpu
=
self
.
valid_sampled_token_count_cpu
if
counts_cpu
is
None
:
return
[]
return
counts_cpu
[:
prev_sampled_token_ids
.
shape
[
0
]].
tolist
()
def
_to_list
(
self
,
sampled_token_ids
:
torch
.
Tensor
)
->
list
[
list
[
int
]]:
"""CPU-safe version: direct tolist() without CUDA events."""
return
sampled_token_ids
.
tolist
()
@
contextmanager
@
contextmanager
def
_torch_cuda_wrapper
():
def
_torch_cuda_wrapper
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment