Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5d6d1adf
Unverified
Commit
5d6d1adf
authored
Jun 04, 2025
by
Vadim Gimpelson
Committed by
GitHub
Jun 03, 2025
Browse files
[KERNEL] Sampler. CUDA kernel for applying repetition penalty (#18437)
parent
1409ef91
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
218 additions
and
9 deletions
+218
-9
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+5
-0
csrc/sampler.cu
csrc/sampler.cu
+86
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
tests/kernels/test_apply_repetition_penalties.py
tests/kernels/test_apply_repetition_penalties.py
+76
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+39
-0
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+4
-9
No files found.
CMakeLists.txt
View file @
5d6d1adf
...
...
@@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
...
...
csrc/ops.h
View file @
5d6d1adf
...
...
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
// Declaration only; the definition is not part of this excerpt.
// csrc/torch_bindings.cpp registers this function as the CUDA implementation
// of the "fused_add_rms_norm" op.
// NOTE(review): `input` and `residual` are taken by non-const reference and
// the function returns void, so both are presumably updated in place --
// confirm against the definition.
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
// Applies repetition penalties to `logits` in place (defined in
// csrc/sampler.cu).
//   logits:               [num_seqs, vocab_size], modified in place
//   prompt_mask:          [num_seqs, vocab_size] bool
//   output_mask:          [num_seqs, vocab_size] bool
//   repetition_penalties: [num_seqs], same scalar type as `logits`
// For every token flagged in either mask, a positive logit is divided by the
// sequence's penalty and any other logit is multiplied by it. The definition
// TORCH_CHECKs that all four tensors are contiguous.
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
// Declaration only; the definition is not part of this excerpt.
// NOTE(review): judging by the name and parameters this looks like RMSNorm of
// `input` with `weight`, quantized to FP8 with a static `scale` and written to
// `out` -- confirm against the definition before relying on this.
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
double
epsilon
);
...
...
csrc/sampler.cu
0 → 100644
View file @
5d6d1adf
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {

// Applies per-sequence repetition penalties to `logits` in place.
//
// Work split: blockIdx.x picks the sequence (row) and blockIdx.y picks a tile
// of `tile_size` consecutive vocabulary entries inside that row. The threads
// of a block walk their tile in steps of blockDim.x.
//
// For every entry flagged in `prompt_mask` or `output_mask`, a positive logit
// is divided by the row's penalty and any other logit is multiplied by it.
// Entries flagged in neither mask are left untouched.
template <typename scalar_t>
__global__ void apply_repetition_penalties_kernel(
    scalar_t* __restrict__ logits,                      // [num_seqs, vocab_size]
    const bool* __restrict__ prompt_mask,               // [num_seqs, vocab_size]
    const bool* __restrict__ output_mask,               // [num_seqs, vocab_size]
    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
    const int num_seqs, const int vocab_size, const int tile_size) {
  const int seq_idx = blockIdx.x;
  if (seq_idx >= num_seqs) return;

  // Half-open vocabulary range [tile_begin, tile_limit) owned by this block;
  // the last tile of a row is clipped to vocab_size.
  const int tile_begin = blockIdx.y * tile_size;
  const int tile_limit = min(tile_begin + tile_size, vocab_size);

  // One penalty per sequence, shared by every entry of the row.
  const scalar_t penalty = repetition_penalties[seq_idx];

  // 64-bit row offset so num_seqs * vocab_size may exceed the int range.
  const int64_t row_offset = static_cast<int64_t>(seq_idx) * vocab_size;

  for (int col = tile_begin + threadIdx.x; col < tile_limit;
       col += blockDim.x) {
    const int64_t idx = row_offset + col;

    // Skip tokens that appear in neither the prompt nor the output so far.
    if (!(prompt_mask[idx] || output_mask[idx])) continue;

    const scalar_t logit = logits[idx];
    logits[idx] = (logit > 0) ? logit / penalty : logit * penalty;
  }
}

}  // namespace vllm
// Host entry point: applies repetition penalties to `logits` in place by
// launching vllm::apply_repetition_penalties_kernel on the current CUDA
// stream of the device that owns `logits`.
//
// All four tensors must be contiguous. `repetition_penalties` is read with
// the same scalar type as `logits`; both masks are read as bool.
//
// The launch grid is (num_seqs, tile_num): each row is cut into `tile_num`
// vocabulary tiles so that small batches still produce roughly one block per
// SM. An empty batch or an empty vocabulary is a no-op.
void apply_repetition_penalties_(
    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& repetition_penalties) {  // [num_seqs]
  TORCH_CHECK(logits.is_contiguous());
  TORCH_CHECK(prompt_mask.is_contiguous());
  TORCH_CHECK(output_mask.is_contiguous());
  TORCH_CHECK(repetition_penalties.is_contiguous());

  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  // Nothing to penalize. Returning here also keeps the tiling arithmetic
  // below from dividing by zero (num_seqs == 0 -> "/ num_seqs";
  // vocab_size == 0 -> tile_num == 0 -> "/ tile_num") and avoids launching a
  // kernel with a zero-sized grid.
  if (num_seqs == 0 || vocab_size == 0) return;

  // Get number of SMs on the current device
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
                         logits.get_device());

  // Compute tile_num and tile_size: aim for about one block per SM across the
  // whole grid, but never more tiles than there are vocabulary entries.
  int tile_num =
      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
  int tile_size = (vocab_size + tile_num - 1) / tile_num;

  // Each block handles one sequence and a tile of vocab
  dim3 grid(num_seqs, tile_num);
  dim3 block(std::min(tile_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
        vllm::apply_repetition_penalties_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
                output_mask.data_ptr<bool>(),
                repetition_penalties.data_ptr<scalar_t>(), num_seqs,
                vocab_size, tile_size);
      });
}
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
5d6d1adf
...
...
@@ -170,6 +170,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Apply repetition penalties to logits in-place
ops
.
def
(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"Tensor output_mask, Tensor repetition_penalties) -> ()"
);
ops
.
impl
(
"apply_repetition_penalties_"
,
torch
::
kCUDA
,
&
apply_repetition_penalties_
);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
...
...
tests/kernels/test_apply_repetition_penalties.py
0 → 100644
View file @
5d6d1adf
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
(
apply_repetition_penalties_cuda
,
apply_repetition_penalties_torch
)
from
vllm.platforms
import
current_platform
# Parameter grids for test_apply_repetition_penalties; each list below feeds
# one pytest.mark.parametrize axis of that test.
# Batch sizes, including values just below, at and just above 1024.
NUM_SEQS
=
[
1
,
2
,
3
,
4
,
8
,
13
,
17
,
32
,
37
,
256
,
1023
,
1024
,
1025
]
# Vocabulary sizes, in order: [stress, stress, stress, Qwen, llama 4]
VOCAB_SIZES
=
[
17
,
256
,
1019
,
151936
,
202048
]
# Penalty value assigned to every sequence of the batch.
REPETITION_PENALTY_VALUES
=
[
1.05
]
# Seeds passed to current_platform.seed_everything.
SEEDS
=
[
0
]
# Dtypes used for both the logits and the repetition penalties.
DTYPES
=
[
torch
.
float32
,
torch
.
float16
]
@pytest.mark.parametrize("num_seqs", NUM_SEQS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZES)
@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test for checking CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties(
    num_seqs: int,
    vocab_size: int,
    repetition_penalty: float,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """
    Check the apply_repetition_penalties custom op against the pure-PyTorch
    reference implementation, then run opcheck on the registered op.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda:0")

    # Random logits for the whole batch.
    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

    # Boolean masks, initially all False.
    mask_shape = (num_seqs, vocab_size)
    prompt_mask = torch.zeros(mask_shape, dtype=torch.bool)
    output_mask = torch.zeros(mask_shape, dtype=torch.bool)

    # Flag a handful of random token ids per sequence (about 0.5% of the
    # vocabulary, at least one) in each mask.
    tokens_per_seq = max(1, vocab_size // 200)
    prompt_indices = torch.randint(0, vocab_size, (num_seqs, tokens_per_seq))
    output_indices = torch.randint(0, vocab_size, (num_seqs, tokens_per_seq))
    prompt_mask.scatter_(1, prompt_indices, True)
    output_mask.scatter_(1, output_indices, True)

    # The same penalty for every sequence.
    repetition_penalties = torch.full((num_seqs, ),
                                      repetition_penalty,
                                      dtype=dtype)

    # Run both implementations on independent copies of the logits.
    logits_torch = logits.clone()
    logits_cuda = logits.clone()
    shared_args = (prompt_mask, output_mask, repetition_penalties)
    apply_repetition_penalties_torch(logits_torch, *shared_args)
    apply_repetition_penalties_cuda(logits_cuda, *shared_args)

    # The CUDA kernel must match the PyTorch reference.
    torch.testing.assert_close(logits_torch,
                               logits_cuda,
                               rtol=1e-3,
                               atol=1e-3)

    # Test the operator by applying the opcheck utility
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
vllm/_custom_ops.py
View file @
5d6d1adf
...
...
@@ -282,6 +282,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
def apply_repetition_penalties_torch(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor,
        repetition_penalties: torch.Tensor) -> None:
    """Pure-PyTorch repetition penalty, applied to ``logits`` in-place.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size]; modified
            in-place.
        prompt_mask: Boolean tensor of shape [num_seqs, vocab_size] marking
            tokens that appear in the prompt.
        output_mask: Boolean tensor of shape [num_seqs, vocab_size] marking
            tokens that appear in the output.
        repetition_penalties: Per-sequence penalties of shape (num_seqs, ).
    """
    # [num_seqs] -> [num_seqs, 1]. torch.where broadcasts this across the
    # vocabulary dimension, so no [num_seqs, vocab_size] copy of the
    # penalties has to be materialized first.
    per_seq_penalty = repetition_penalties.unsqueeze(dim=1)
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, per_seq_penalty, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling
def apply_repetition_penalties_cuda(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor,
        repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to ``logits`` in-place via the ``_C``
    custom op ``apply_repetition_penalties_``.

    Arguments are forwarded unchanged; see ``apply_repetition_penalties`` for
    their meaning.
    """
    kernel_op = torch.ops._C.apply_repetition_penalties_
    kernel_op(logits, prompt_mask, output_mask, repetition_penalties)
def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to logits in-place.

    Uses the CUDA custom op when its preconditions hold and falls back to the
    pure-PyTorch implementation otherwise.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    # The custom op is registered for the CUDA backend only and rejects any
    # non-contiguous argument (it TORCH_CHECKs all four tensors), so only
    # dispatch to it when every one of those conditions is met. Anything else
    # takes the PyTorch path instead of raising.
    kernel_args = (logits, prompt_mask, output_mask, repetition_penalties)
    use_cuda_kernel = (current_platform.is_cuda() and logits.is_cuda
                       and all(t.is_contiguous() for t in kernel_args))
    if use_cuda_kernel:
        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
                                        repetition_penalties)
    else:
        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
                                         repetition_penalties)
def
advance_step_flashattn
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/utils.py
View file @
5d6d1adf
...
...
@@ -50,16 +50,11 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size
,
num_seqs
)
output_bin_counts
,
output_mask
=
get_token_bin_counts_and_mask
(
output_tokens_tensor
,
vocab_size
,
num_seqs
)
repetition_penalties
=
repetition_penalties
.
unsqueeze
(
dim
=
1
).
repeat
(
1
,
vocab_size
)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties
=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling
=
torch
.
where
(
logits
>
0
,
1.0
/
penalties
,
penalties
)
logits
*=
scaling
# Apply repetition penalties as a custom op
from
vllm._custom_ops
import
apply_repetition_penalties
apply_repetition_penalties
(
logits
,
prompt_mask
,
output_mask
,
repetition_penalties
)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment