OpenDAS / vllm_cscc / Commits / 911c8eb0

Commit 911c8eb0 (Unverified)
Authored Mar 24, 2025 by Woosuk Kwon
Committed by GitHub, Mar 24, 2025

[Minor][Spec Decode] Remove compiled_softmax (#15416)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent ebcebeeb

Changes: 2
Showing 2 changed files with 1 addition and 33 deletions (+1 -33)

vllm/v1/sample/ops/utils.py          +0 -30
vllm/v1/sample/rejection_sampler.py  +1 -3

vllm/v1/sample/ops/utils.py    deleted, 100644 → 0    (file @ ebcebeeb)
# SPDX-License-Identifier: Apache-2.0
from typing import Union

import torch


def compiled_softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
    """Faster softmax kernel generated by torch.compile.

    Args:
        logits: [n, vocab_size]
        temperature: [n] or float
    """
    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
    torch._dynamo.mark_dynamic(logits, index=0)
    if isinstance(temperature, torch.Tensor):
        torch._dynamo.mark_dynamic(temperature, index=0)
    return _softmax(logits, temperature)


@torch.compile
def _softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
    logits = logits / temperature
    return torch.softmax(logits, dim=-1, dtype=torch.float32)

vllm/v1/sample/rejection_sampler.py    (file @ 911c8eb0)
@@ -9,7 +9,6 @@ import triton.language as tl
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 logger = init_logger(__name__)
@@ -275,8 +274,7 @@ def compute_probs(
     # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
     # which is slow for large vocab sizes. This may cause performance issues.
     logits = apply_top_k_top_p(logits, top_k, top_p)
-    output_prob = compiled_softmax(logits)
+    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
     return output_prob
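
Reviewer note: the removed call site used `compiled_softmax(logits)` with the default `temperature=1.0`, so dividing by the temperature was a no-op there and a plain softmax should give the same probabilities. The sketch below illustrates that with the standard library only; it is not vLLM code. The names `softmax` and `softmax_with_temperature` are hypothetical stand-ins, and reading `temperature: [n]` as one temperature per row is an assumption.

```python
import math
from typing import List, Sequence, Union


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_with_temperature(
    logits: Sequence[Sequence[float]],
    temperature: Union[float, Sequence[float]] = 1.0,
) -> List[List[float]]:
    """Hypothetical pure-Python stand-in for the removed helper.

    logits: n rows of vocab_size floats
    temperature: one float, or n per-row floats (assumed interpretation)
    """
    if isinstance(temperature, (int, float)):
        temps = [float(temperature)] * len(logits)
    else:
        temps = list(temperature)
    # Scale each row by its temperature, then normalize.
    return [softmax([x / t for x in row]) for row, t in zip(logits, temps)]


logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]]

# Old call shape: helper with the default temperature of 1.0.
old_style = softmax_with_temperature(logits)
# New call shape: plain softmax, no temperature argument.
new_style = [softmax(row) for row in logits]

# A higher temperature flattens the distribution, for contrast.
flattened = softmax_with_temperature(logits, 2.0)
```

Because dividing a float by 1.0 is exact, `old_style` and `new_style` match element for element in this sketch. Any difference in the real change would come from the compiled versus eager kernels rather than from the math.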