Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8267f991
Unverified
Commit
8267f991
authored
Jun 06, 2025
by
Yu Guo
Committed by
GitHub
Jun 06, 2025
Browse files
improve logits bias (#19041)
parent
7353492a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
1 deletion
+16
-1
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+16
-1
No files found.
vllm/v1/sample/sampler.py
View file @
8267f991
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.utils
import
async_tensor_h2d
,
is_pin_memory_available
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.bad_words
import
apply_bad_words
from
vllm.v1.sample.ops.bad_words
import
apply_bad_words
...
@@ -20,6 +21,7 @@ class Sampler(nn.Module):
...
@@ -20,6 +21,7 @@ class Sampler(nn.Module):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
pin_memory
=
is_pin_memory_available
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -232,6 +234,10 @@ class Sampler(nn.Module):
...
@@ -232,6 +234,10 @@ class Sampler(nn.Module):
# One idea is implement this as a PyTorch C++ op, and we may
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# even optimize the logit_bias layout.
rows
:
list
[
int
]
=
[]
cols
:
list
[
int
]
=
[]
vals
:
list
[
float
]
=
[]
# Get vocabulary size from logits
# Get vocabulary size from logits
vocab_size
=
logits
.
shape
[
-
1
]
vocab_size
=
logits
.
shape
[
-
1
]
...
@@ -244,7 +250,16 @@ class Sampler(nn.Module):
...
@@ -244,7 +250,16 @@ class Sampler(nn.Module):
f
"token_id
{
token_id
}
in logit_bias contains "
f
"token_id
{
token_id
}
in logit_bias contains "
f
"out-of-vocab token id. Vocabulary size: "
f
"out-of-vocab token id. Vocabulary size: "
f
"
{
vocab_size
}
"
)
f
"
{
vocab_size
}
"
)
logits
[
i
,
token_id
]
+=
bias
rows
.
append
(
i
)
cols
.
append
(
token_id
)
vals
.
append
(
bias
)
if
rows
:
indices
=
async_tensor_h2d
([
rows
,
cols
],
torch
.
int64
,
logits
.
device
,
self
.
pin_memory
)
values
=
async_tensor_h2d
(
vals
,
torch
.
float
,
logits
.
device
,
self
.
pin_memory
)
logits
.
index_put_
(
tuple
(
indices
),
values
=
values
,
accumulate
=
True
)
return
logits
return
logits
def
apply_allowed_token_ids
(
def
apply_allowed_token_ids
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment