Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
81b979f2
Unverified
Commit
81b979f2
authored
Dec 27, 2024
by
Woosuk Kwon
Committed by
GitHub
Dec 27, 2024
Browse files
[V1] Fix yapf (#11538)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
371d04d3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
19 deletions
+21
-19
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+13
-11
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+8
-8
No files found.
vllm/v1/sample/ops/penalties.py
View file @
81b979f2
...
...
@@ -2,8 +2,7 @@ from typing import List, Set, Tuple
import
torch
from
vllm.model_executor.layers.utils
import
(
apply_penalties
as
_apply_penalties
)
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
...
...
@@ -17,25 +16,28 @@ def apply_min_token_penalties(logits: torch.Tensor,
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
if
(
len
(
output_token_ids
[
index
])
<
min_token
)
:
if
len
(
output_token_ids
[
index
])
<
min_token
:
for
stop_token_id
in
stop_token_ids
[
index
]:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
def
apply_penalties
(
logits
:
torch
.
Tensor
,
prompt_token_ids
:
torch
.
Tensor
,
def
apply_all_penalties
(
logits
:
torch
.
Tensor
,
prompt_token_ids
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]])
->
torch
.
Tensor
:
output_token_ids
:
List
[
List
[
int
]],
)
->
torch
.
Tensor
:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_
,
vocab_size
=
logits
.
shape
output_tokens_t
=
_convert_to_tensors
(
output_token_ids
,
vocab_size
,
logits
.
device
)
return
_
apply_penalties
(
logits
,
prompt_token_ids
,
output_tokens_t
,
return
apply_penalties
(
logits
,
prompt_token_ids
,
output_tokens_t
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
)
...
...
vllm/v1/sample/sampler.py
View file @
81b979f2
...
...
@@ -6,8 +6,8 @@ import torch.nn as nn
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_
min_token
_penalties
,
apply_penalties
)
from
vllm.v1.sample.ops.penalties
import
(
apply_
all
_penalties
,
apply_
min_token_
penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
_SAMPLING_EPS
=
1e-5
...
...
@@ -127,8 +127,8 @@ class Sampler(nn.Module):
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
logits
=
apply_
all_
penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment