sglang / Commits / e2b16c47

Commit e2b16c47 (unverified), authored Jan 13, 2025 by Xiaoyu Zhang, committed by GitHub Jan 12, 2025.

add sampling_scaling_penalties kernel (#2846)

Parent: c4f9707e
Showing 9 changed files with 150 additions and 1 deletion (+150 −1).
Changed files:
- sgl-kernel/CMakeLists.txt (+1 −0)
- sgl-kernel/pyproject.toml (+1 −1)
- sgl-kernel/setup.py (+1 −0)
- sgl-kernel/src/sgl-kernel/__init__.py (+2 −0)
- sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu (+64 −0)
- sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+5 −0)
- sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh (+30 −0)
- sgl-kernel/src/sgl-kernel/ops/__init__.py (+7 −0)
- sgl-kernel/tests/test_sampling_scaling_penalties.py (+39 −0)
sgl-kernel/CMakeLists.txt

```diff
@@ -32,6 +32,7 @@ add_library(_kernels SHARED
     src/sgl-kernel/csrc/trt_reduce_kernel.cu
     src/sgl-kernel/csrc/moe_align_kernel.cu
     src/sgl-kernel/csrc/int8_gemm_kernel.cu
+    src/sgl-kernel/csrc/sampling_scaling_penalties.cu
     src/sgl-kernel/csrc/sgl_kernel_ops.cu
 )
```
sgl-kernel/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post11"
+version = "0.0.2.post12"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
```
sgl-kernel/setup.py

```diff
@@ -50,6 +50,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
             "src/sgl-kernel/csrc/moe_align_kernel.cu",
             "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
         ],
         include_dirs=include_dirs,
```
sgl-kernel/src/sgl-kernel/__init__.py

```diff
@@ -4,6 +4,7 @@ from sgl_kernel.ops import (
     init_custom_reduce,
     int8_scaled_mm,
     moe_align_block_size,
+    sampling_scaling_penalties,
 )
@@ -12,4 +13,5 @@ __all__ = [
     "custom_dispose",
     "custom_reduce",
     "int8_scaled_mm",
+    "sampling_scaling_penalties",
 ]
```
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu (new file, mode 100644)

```cuda
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

#include "utils.hpp"
#include "vectorization.cuh"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int32_t numel) {
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);

  const int32_t num_vec_elems = numel >> 2;

#pragma unroll 4
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
    vec4_t<scalar_t> out_vec;

    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

    vectorized_output[i] = out_vec;
  }

  const int32_t start_idx = num_vec_elems * 4;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > 0 ? logit / penalty : logit * penalty;
  }
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
        const int blocks = (numel + threads * 4 - 1) / (threads * 4);
        sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            logits.data_ptr<scalar_t>(), scaling_penalties.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), numel);
      }));

  return output;
}
```
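The element-wise rule follows the usual repetition-penalty convention: positive logits are divided by the penalty and non-positive logits are multiplied by it, so any penalty above 1 pushes the token's probability down. A minimal CPU sketch of the same rule in PyTorch (illustrative values only, not part of the commit):

```python
import torch

logits = torch.tensor([2.0, -1.5, 0.0])
penalties = torch.tensor([1.2, 1.2, 1.2])

# Positive logits shrink via division; non-positive logits are multiplied,
# mirroring the ternary expressions in the CUDA kernel above.
out = torch.where(logits > 0, logits / penalties, logits * penalties)
print(out)  # tensor([ 1.6667, -1.8000,  0.0000])
```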
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

```diff
@@ -12,6 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer,
                           torch::Tensor cumsum_buffer);

+// sampling_scaling_penalties
+torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);
+
 // int8_scaled_mm
 torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
@@ -24,6 +27,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
+  // sampling_scaling_penalties
+  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
   // int8_scaled_mm
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
 }
```
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh (new file, mode 100644)

```cuda
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
#pragma once
/**
 * __device__ datatypes vectorized by 4
 */

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>

// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> || std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
```
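vec4_t groups four scalars behind an alignment attribute so the kernel above can issue one vectorized load per group; a scalar tail loop then covers the up-to-three leftover elements when numel is not a multiple of 4. A small Python sketch of that work decomposition (threads=512 matches the kernel's launch config; the helper name is made up for illustration):

```python
def split_vectorized(numel: int, threads: int = 512):
    """Mirror the CUDA kernel's index arithmetic: vec4 groups covered by the
    main loop, the first scalar-tail index, and the grid size it launches."""
    num_vec_elems = numel >> 2                            # full groups of 4
    start_idx = num_vec_elems * 4                         # tail begins here
    blocks = (numel + threads * 4 - 1) // (threads * 4)   # 4 elems per thread
    return num_vec_elems, start_idx, blocks

print(split_vectorized(32767))  # (8191, 32764, 16) -> 3-element scalar tail
```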
sgl-kernel/src/sgl-kernel/ops/__init__.py

```diff
@@ -3,6 +3,9 @@ from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import (
+    sampling_scaling_penalties as _sampling_scaling_penalties,
+)


 def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
@@ -39,6 +42,10 @@ def moe_align_block_size(
     )


+def sampling_scaling_penalties(logits, scaling_penalties):
+    return _sampling_scaling_penalties(logits, scaling_penalties)
+
+
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return _int8_scaled_mm(
         mat_a,
```
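With the wrapper exported, callers can apply penalties in one fused kernel call instead of materializing the torch.where expression; a hedged usage sketch (shapes and dtype chosen for illustration):

```python
import torch
from sgl_kernel import sampling_scaling_penalties

# Penalties are kept strictly positive, as in the test below.
logits = torch.randn(8, 32768, device="cuda", dtype=torch.float16)
penalties = torch.rand_like(logits) + 0.5

scaled = sampling_scaling_penalties(logits, penalties)  # same shape as logits
```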
sgl-kernel/tests/test_sampling_scaling_penalties.py (new file, mode 100644)

```python
import torch
from sgl_kernel import sampling_scaling_penalties


def test_sampling_scaling_penalties():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
    vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
    dtypes = [torch.float32, torch.half, torch.bfloat16]
    device = torch.device("cuda")

    for dtype in dtypes:
        rtol = 1e-3
        atol = 1e-3

        for bs in batch_sizes:
            for vocab_size in vocab_sizes:
                logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
                scaling_penalties = (
                    torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
                )

                ref_output = torch.where(
                    logits > 0, logits / scaling_penalties, logits * scaling_penalties
                )

                kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

                torch.testing.assert_close(
                    kernel_output,
                    ref_output,
                    rtol=rtol,
                    atol=atol,
                    msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
                )


if __name__ == "__main__":
    test_sampling_scaling_penalties()
    print("All tests passed!")
```