Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0af3d4f0
Unverified
Commit
0af3d4f0
authored
Nov 19, 2025
by
vllmellm
Committed by
GitHub
Nov 18, 2025
Browse files
[FEAT] [AITER] [ROCm] integrate aiter sampling ops (#26084)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
da8dadf6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
77 additions
and
0 deletions
+77
-0
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+77
-0
No files found.
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
0af3d4f0
...
@@ -7,6 +7,7 @@ import torch.nn as nn
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
packaging
import
version
from
packaging
import
version
from
vllm
import
envs
from
vllm
import
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config.model
import
LogprobsMode
from
vllm.config.model
import
LogprobsMode
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
...
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
...
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
else
:
else
:
self
.
forward
=
self
.
forward_cpu
self
.
forward
=
self
.
forward_cpu
elif
(
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
)
and
rocm_aiter_ops
.
is_enabled
()
):
import
aiter.ops.sampling
# noqa: F401
self
.
aiter_ops
=
torch
.
ops
.
aiter
logger
.
info_once
(
"Using aiter sampler on ROCm (lazy import, sampling-only)."
)
self
.
forward
=
self
.
forward_hip
else
:
else
:
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
...
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
...
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
),
logits_to_return
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
),
logits_to_return
def
forward_hip
(
self
,
logits
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
torch
.
Tensor
|
None
,
p
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
if
(
k
is
None
and
p
is
None
)
or
generators
:
if
generators
:
logger
.
warning_once
(
"aiter sampler does not support per-request generators; "
"falling back to PyTorch-native."
)
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
)
assert
self
.
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
,
),
"aiter sampler does not support returning logits/logprobs."
return
self
.
aiter_sample
(
logits
,
k
,
p
,
generators
),
None
def
aiter_sample
(
self
,
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
|
None
,
p
:
torch
.
Tensor
|
None
,
generators
:
dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Sample from logits using aiter ops."""
use_top_k
=
k
is
not
None
use_top_p
=
p
is
not
None
# Joint k+p path
if
use_top_p
and
use_top_k
:
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
).
contiguous
()
next_token_ids
=
self
.
aiter_ops
.
top_k_top_p_sampling_from_probs
(
probs
,
None
,
*
_to_tensor_scalar_tuple
(
k
),
*
_to_tensor_scalar_tuple
(
p
),
deterministic
=
True
,
)
return
next_token_ids
.
view
(
-
1
)
# Top-p only path
elif
use_top_p
:
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
).
contiguous
()
next_token_ids
=
self
.
aiter_ops
.
top_p_sampling_from_probs
(
probs
,
None
,
*
_to_tensor_scalar_tuple
(
p
),
deterministic
=
True
)
return
next_token_ids
.
view
(
-
1
)
# Top-k only path
elif
use_top_k
:
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
).
contiguous
()
renorm_probs
=
self
.
aiter_ops
.
top_k_renorm_probs
(
probs
,
*
_to_tensor_scalar_tuple
(
k
)
)
return
torch
.
multinomial
(
renorm_probs
,
num_samples
=
1
).
view
(
-
1
)
raise
RuntimeError
(
"aiter_sample was called with no active top-k or top-p."
)
# Note: this is a workaround for
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
# https://github.com/pytorch/pytorch/pull/151218
...
@@ -288,3 +358,10 @@ def flashinfer_sample(
...
@@ -288,3 +358,10 @@ def flashinfer_sample(
)
)
return
next_token_ids
.
view
(
-
1
)
return
next_token_ids
.
view
(
-
1
)
def _to_tensor_scalar_tuple(x):
    """Split ``x`` into a ``(tensor, scalar)`` pair.

    A ``torch.Tensor`` is placed in the first slot with ``0`` as the scalar;
    any other value is placed in the second slot with ``None`` as the tensor.
    """
    is_tensor = isinstance(x, torch.Tensor)
    return (x, 0) if is_tensor else (None, x)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment