sglang: commit eb06dbcb (unverified)
Authored by Lianmin Zheng, Mar 09, 2025; committed via GitHub, Mar 09, 2025
Parent: 9dfafa74

Move rope and bmm into sgl-kernel (#4241)
Showing 5 changed files with 183 additions and 13 deletions (+183, -13):

  sgl-kernel/csrc/elementwise/rope.cu   +89  -0
  sgl-kernel/csrc/gemm/bmm_fp8.cu       +76  -0
  sgl-kernel/include/sgl_kernel_ops.h    +9  -9
  sgl-kernel/pyproject.toml              +6  -1
  sgl-kernel/setup.py                    +3  -3
sgl-kernel/csrc/elementwise/rope.cu (new file, mode 100644)
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/pos_enc.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(cos_sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  unsigned int rotary_dim = cos_sin_cache.size(1);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
        static_cast<c_type*>(q.data_ptr()),
        static_cast<c_type*>(k.data_ptr()),
        static_cast<c_type*>(q_rope.data_ptr()),
        static_cast<c_type*>(k_rope.data_ptr()),
        static_cast<float*>(cos_sin_cache.data_ptr()),
        static_cast<int32_t*>(pos_ids.data_ptr()),
        nnz,
        num_qo_heads,
        num_kv_heads,
        rotary_dim,
        head_dim,
        q_stride_n,
        q_stride_h,
        k_stride_n,
        k_stride_h,
        q_rope_stride_n,
        q_rope_stride_h,
        k_rope_stride_n,
        k_rope_stride_h,
        interleave,
        stream);
    TORCH_CHECK(
        status == cudaSuccess,
        "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
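For reference, the wrapper above expects cos_sin_cache shaped (max_seq_len, rotary_dim) with the cos table in the first half of the last dimension and the sin table in the second half, plus int32 pos_ids. Below is a minimal Python sketch of one way such a cache could be precomputed; the helper name, base value, and dtypes are illustrative assumptions and are not part of this commit.

import torch

def build_cos_sin_cache(max_seq_len: int, rotary_dim: int,
                        base: float = 10000.0, device: str = "cuda") -> torch.Tensor:
    # Standard RoPE frequencies: one per pair of rotated channels.
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device) / rotary_dim
    inv_freq = 1.0 / (base ** exponents)
    t = torch.arange(max_seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(t, inv_freq)                      # (max_seq_len, rotary_dim // 2)
    # Layout assumed by the wrapper: first half cos, second half sin.
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_seq_len, rotary_dim)

# Expected call shapes (the Python-side binding name is not shown in this diff):
#   q, k:    (nnz, H_Q, D) and (nnz, H_K, D), fp16/bf16, last dim contiguous
#   pos_ids: (nnz,) int32
#   cache:   (max_seq_len, rotary_dim) float32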
sgl-kernel/csrc/gemm/bmm_fp8.cu (new file, mode 100644)
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <driver_types.h>

#include <flashinfer/gemm/bmm_fp8.cuh>

#include "pytorch_extension_utils.h"

void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
  TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
  TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
  TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
  TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
  TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape");

  // PyTorch is row major by default. cuBLASLt is column major by default.
  // We need row major D as expected.
  // A ^ T * B = D, so D ^ T = B ^ T * A
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
        auto batch_size = A.size(0);
        auto m = A.size(1);
        auto k = A.size(2);
        auto n = B.size(2);

        auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
        auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

        auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
            workspace_buffer.data_ptr(),
            workspace_buffer.numel(),
            static_cast<b_type*>(B.data_ptr()),
            static_cast<a_type*>(A.data_ptr()),
            static_cast<d_type*>(D.data_ptr()),
            batch_size,
            n,
            m,
            k,
            static_cast<float*>(B_scale.data_ptr()),
            static_cast<float*>(A_scale.data_ptr()),
            lt_handle,
            stream);
        TORCH_CHECK(
            status == CUBLAS_STATUS_SUCCESS,
            "bmm_fp8_internal_cublaslt failed: ", cublasGetStatusString(status));
        return true;
      });
    });
  });
}
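The comment in the wrapper records why the operands are swapped: a row-major D occupies the same memory as a column-major D^T, so the call asks cuBLASLt for D^T = B^T * A, passing B first with the problem size given as (n, m, k). As a rough Python reference for the intended math only, assuming per-tensor dequantization scales; the function name and dtypes below are illustrative and not part of this diff.

import torch

def bmm_fp8_reference(A_fp8: torch.Tensor, B_fp8: torch.Tensor,
                      A_scale: torch.Tensor, B_scale: torch.Tensor,
                      out_dtype=torch.bfloat16) -> torch.Tensor:
    # Dequantize the per-tensor-scaled fp8 inputs and run the batched matmul
    # in float32; the CUDA path above fuses this into a single cuBLASLt call.
    A = A_fp8.to(torch.float32) * A_scale
    B = B_fp8.to(torch.float32) * B_scale
    return torch.bmm(A, B).to(out_dtype)  # (b, m, k) @ (b, k, n) -> (b, m, n)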
sgl-kernel/include/sgl_kernel_ops.h
...
@@ -140,6 +140,15 @@ void cublas_grouped_gemm(
     const torch::Dtype& out_dtype,
     int64_t cublas_handle,
     int64_t cuda_stream);
 
+void bmm_fp8(
+    at::Tensor A,
+    at::Tensor B,
+    at::Tensor D,
+    at::Tensor A_scale,
+    at::Tensor B_scale,
+    at::Tensor workspace_buffer,
+    int64_t cublas_handle,
+    int64_t cuda_stream);
 /*
  * From csrc/moe
...
@@ -198,15 +207,6 @@ void build_tree_kernel(
 /*
  * From FlashInfer
  */
-void bmm_fp8(
-    at::Tensor A,
-    at::Tensor B,
-    at::Tensor D,
-    at::Tensor A_scale,
-    at::Tensor B_scale,
-    at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
 void min_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
...
sgl-kernel/pyproject.toml
 [build-system]
-requires = ["setuptools>=61.0", "wheel", "torch"]
+requires = [
+  "setuptools>=61.0",
+  "scikit-build-core>=0.10",
+  "torch==2.5.1",
+  "wheel",
+]
 build-backend = "setuptools.build_meta"
 
 [project]
...
sgl-kernel/setup.py
...
@@ -97,6 +97,8 @@ sources = [
     "csrc/allreduce/trt_reduce_kernel.cu",
     "csrc/attention/lightning_attention_decode_kernel.cu",
     "csrc/elementwise/fused_add_rms_norm_kernel.cu",
+    "csrc/elementwise/rope.cu",
+    "csrc/gemm/bmm_fp8.cu",
     "csrc/gemm/cublas_grouped_gemm.cu",
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
...
@@ -109,11 +111,9 @@ sources = [
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
     "3rdparty/flashinfer/csrc/activation.cu",
-    "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
-    "3rdparty/flashinfer/csrc/sampling.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
-    "3rdparty/flashinfer/csrc/rope.cu",
+    "3rdparty/flashinfer/csrc/sampling.cu",
 ]
 
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
...