sglang · Commits · 95f789ad

Unverified commit 95f789ad, authored Jan 26, 2025 by Yineng Zhang; committed by GitHub on Jan 26, 2025.

minor: cleanup sgl-kernel (#3143)

Parent: 4f118a39
Showing 13 changed files with 11 additions and 225 deletions (+11 −225)
sgl-kernel/developer_guide.md                                         +4   −0
sgl-kernel/setup.py                                                   +0   −2
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu                  +0   −92
sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu   +1   −2
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu                    +1   −16
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu          +0   −61
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu                 +1   −0
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu                   +1   −0
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h                   +1   −5
sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh             +1   −2
sgl-kernel/src/sgl-kernel/include/utils.h                             +1   −2
sgl-kernel/src/sgl-kernel/torch_extension.cc                          +0   −4
sgl-kernel/tests/test_sampling_scaling_penalties.py                   +0   −39
sgl-kernel/developer_guide.md

@@ -40,6 +40,10 @@ Development build:
 make build
 ```
+
+Note: The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.

 ### Testing & Benchmarking

 1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
...
sgl-kernel/setup.py

@@ -82,10 +82,8 @@ sources = [
     "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
-    "src/sgl-kernel/csrc/fused_add_rms_norm.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/group_gemm.cu",
...
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu
deleted 100644 → 0

// Adapted from
// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu
#include <turbomind/kernels/core/array_ops.h>
#include <turbomind/kernels/core/common.h>

#include <cub/block/block_reduce.cuh>

using namespace turbomind;

template <class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states,
                                          const T* __restrict__ weights, const T* __restrict__ bias, int dims,
                                          int num, float eps, float inv_dims) {
  const int ti = blockIdx.x;
  const int di = threadIdx.x * vec_size;

  if (ti >= num) {
    return;
  }

  residual += dims * ti;
  hidden_states += dims * ti;

  Array<Tacc, vec_size> accum{};

  Array<T, vec_size> r_vec;
  Array<T, vec_size> h_vec;
  Array<T, vec_size> b_vec;

  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Load(h_vec, &hidden_states[i]);

    using namespace ops;
    r_vec = r_vec + h_vec;

    if (bias) {
      Ldg(b_vec, &bias[i]);
      r_vec = r_vec + b_vec;
    }

    Store(&residual[i], r_vec);

    Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);
    accum = accum + tmp * tmp;
  }

  float sum{};
  PRAGMA_UNROLL
  for (int i = 0; i < vec_size; ++i) {
    sum += accum[i];
  }

  using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  sum = BlockReduce{temp_storage}.Sum(sum);

  __shared__ float shared_sum;
  if (threadIdx.x == 0) {
    shared_sum = rsqrtf(sum * inv_dims + eps);
  }
  __syncthreads();

  sum = shared_sum;

  Array<T, vec_size> w_vec;
  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Ldg(w_vec, &weights[i]);
    PRAGMA_UNROLL
    for (int c = 0; c < vec_size; ++c) {
      r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
    }
    Store(&hidden_states[i], r_vec);
  }
}

template <class T>
void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num,
                               float eps, cudaStream_t st) {
  constexpr int vec_size = 16 / sizeof(T);
  constexpr int threads = 512;
  const int blocks = num;

  BiasResidualRMSNormKernel<T, float, threads, vec_size>
      <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
}
sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu

@@ -3,8 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
+#include "utils.h"
-#include <torch/extension.h>

 #define THREADS_PER_BLOCK 128
...
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu

@@ -3,28 +3,14 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 #include <THC/THCAtomics.cuh>
+#include "utils.h"
-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
-#endif
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize
-#endif
-#ifndef USE_ROCM
-#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
-  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
-#else
-#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
-  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
-#endif
 #define CEILDIV(x, y) (((x) + (y)-1) / (y))
...
@@ -39,7 +25,6 @@ AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

 __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
   // don't worry about overflow because num_experts is relatively small
   return row * total_col + col;
 }
...
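As an aside, the `CEILDIV` macro and the row-major `index` device helper kept above reduce to simple integer arithmetic; a quick illustrative sketch in Python (not part of the commit):

```python
def ceildiv(x: int, y: int) -> int:
    # mirrors CEILDIV(x, y) = ((x) + (y) - 1) / (y) under C integer division
    return (x + y - 1) // y

def flat_index(total_col: int, row: int, col: int) -> int:
    # mirrors index(): row-major offset into a 2-D buffer with total_col columns
    return row * total_col + col

assert ceildiv(10, 4) == 3        # 10 elements in chunks of 4 -> 3 chunks
assert flat_index(8, 2, 3) == 19  # row 2, col 3 in an 8-column layout
```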
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
deleted 100644 → 0

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int32_t numel) {
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int32_t num_vec_elems = numel / vec_size;

#pragma unroll 1
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec_t logits_vec, penalties_vec, out_vec;
    logits_vec.cast_load(logits + i * vec_size);
    penalties_vec.cast_load(scaling_penalties + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      out_vec[j] =
          logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
    }

    out_vec.cast_store(output + i * vec_size);
  }

  // process the remaining elements
  const int32_t start_idx = num_vec_elems * vec_size;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
  }
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
    uint32_t vec_size = 16 / sizeof(scalar_t);
    const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
    sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        static_cast<scalar_t*>(logits.data_ptr()), static_cast<scalar_t*>(scaling_penalties.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()), numel);
    return true;
  });

  return output;
}
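Semantically, the deleted op applies a scaling penalty elementwise: positive logits are divided by the penalty, non-positive logits are multiplied by it. This matches the reference expression in the (also deleted) test further down, and in plain PyTorch it is a one-liner (illustrative sketch, not part of the commit):

```python
import torch

def sampling_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    # divide positive logits by the penalty, multiply non-positive ones by it
    return torch.where(logits > 0, logits / penalties, logits * penalties)
```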
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu

@@ -26,6 +26,7 @@
 #include <tuple>

 #include "trt_reduce_internal.cuh"
+#include "utils.h"

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu

@@ -5,6 +5,7 @@
 #include <cassert>

 #include "trt_reduce_internal.cuh"
+#include "utils.h"

 using namespace trt_llm;
...
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h

 #pragma once

 #include <Python.h>
 #include <torch/extension.h>
 #include <vector>

 #include "utils.h"

 #define _CONCAT(A, B) A##B
 #define CONCAT(A, B) _CONCAT(A, B)
...
@@ -36,9 +35,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
     torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
     torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);

-// sampling_scaling_penalties
-torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);

 // int8_scaled_mm
 torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
...
sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh

@@ -17,12 +17,11 @@
  */
 #pragma once

 #include <cuda_fp16.h>
 #include <stdint.h>
 #include <torch/all.h>
+#include "utils.h"

 namespace trt_llm {
 constexpr size_t WARP_SIZE = 32;
 constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
...
sgl-kernel/src/sgl-kernel/include/utils.h

 #pragma once

 #include <cuda_runtime.h>
 #include <pytorch_extension_utils.h>
 #include <torch/extension.h>

 #include <sstream>

 #include "sgl_kernels_ops.h"

 struct cuda_error : public std::runtime_error {
   /**
    * @brief Constructs a `cuda_error` object with the given `message`.
...
sgl-kernel/src/sgl-kernel/torch_extension.cc

@@ -28,10 +28,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

-  // sampling_scaling_penalties
-  m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
-  m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);

   // int8_scaled_mm
   m.def(
       "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
...
sgl-kernel/tests/test_sampling_scaling_penalties.py
deleted 100644 → 0

import pytest
import torch
from sgl_kernel import sampling_scaling_penalties

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
dtypes = [torch.float32, torch.half, torch.bfloat16]


@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("vocab_size", vocab_sizes)
@pytest.mark.parametrize("dtype", dtypes)
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
    device = torch.device("cuda")
    rtol = 1e-3
    atol = 1e-3

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    ref_output = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )

    kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

    torch.testing.assert_close(
        kernel_output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}",
    )


if __name__ == "__main__":
    pytest.main([__file__])