Commit eb06dbcb, authored Mar 09, 2025 by Lianmin Zheng, committed by GitHub on Mar 09, 2025
Move rope and bmm into sgl-kernel (#4241)
Parent commit: 9dfafa74
Showing 5 changed files, with 183 additions and 13 deletions (+183, -13):
sgl-kernel/csrc/elementwise/rope.cu      +89  -0
sgl-kernel/csrc/gemm/bmm_fp8.cu          +76  -0
sgl-kernel/include/sgl_kernel_ops.h      +9   -9
sgl-kernel/pyproject.toml                +6   -1
sgl-kernel/setup.py                      +3   -3
sgl-kernel/csrc/elementwise/rope.cu  (new file, mode 0 → 100644)
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/pos_enc.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(cos_sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  unsigned int rotary_dim = cos_sin_cache.size(1);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
        static_cast<c_type*>(q.data_ptr()),
        static_cast<c_type*>(k.data_ptr()),
        static_cast<c_type*>(q_rope.data_ptr()),
        static_cast<c_type*>(k_rope.data_ptr()),
        static_cast<float*>(cos_sin_cache.data_ptr()),
        static_cast<int32_t*>(pos_ids.data_ptr()),
        nnz,
        num_qo_heads,
        num_kv_heads,
        rotary_dim,
        head_dim,
        q_stride_n,
        q_stride_h,
        k_stride_n,
        k_stride_h,
        q_rope_stride_n,
        q_rope_stride_h,
        k_rope_stride_n,
        k_rope_stride_h,
        interleave,
        stream);
    TORCH_CHECK(status == cudaSuccess,
                "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
                    std::string(cudaGetErrorString(status)));
    return true;
  });
}
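As a reading aid, below is a minimal PyTorch sketch of the transformation this wrapper dispatches to, assuming the non-interleaved (NeoX-style) layout and the shape conventions documented in the checks above (q/k: (nnz, H, D); cos_sin_cache: (max_seq_len, R) with cos in the first half of the last dimension and sin in the second; pos_ids: (nnz,) int32). The function name and the exact rotary convention are illustrative assumptions, not part of this commit; q and k are handled identically, and the kernel additionally supports an interleaved layout selected by the interleave flag.

import torch

def rope_with_cos_sin_cache_ref(x, cos_sin_cache, pos_ids):
    # x: (nnz, H, D) fp16/bf16, cos_sin_cache: (max_seq_len, R) float32,
    # pos_ids: (nnz,) int32. Only the first R dims of each head are rotated.
    rotary_dim = cos_sin_cache.size(1)
    half = rotary_dim // 2
    cos = cos_sin_cache[pos_ids.long(), :half].unsqueeze(1)  # (nnz, 1, half)
    sin = cos_sin_cache[pos_ids.long(), half:].unsqueeze(1)  # (nnz, 1, half)
    x1 = x[..., :half].float()
    x2 = x[..., half:rotary_dim].float()
    out = x.clone()
    out[..., :half] = (x1 * cos - x2 * sin).to(x.dtype)
    out[..., half:rotary_dim] = (x2 * cos + x1 * sin).to(x.dtype)
    return out  # dims beyond rotary_dim pass through unchanged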
sgl-kernel/csrc/gemm/bmm_fp8.cu  (new file, mode 0 → 100644)
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <driver_types.h>

#include <flashinfer/gemm/bmm_fp8.cuh>

#include "pytorch_extension_utils.h"

void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
  TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
  TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
  TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
  TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
  TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
              "Result tensor has incorrect shape");

  // PyTorch is row major by default. cuBLASLt is column major by default.
  // We need row major D as expected.
  // A ^ T * B = D, so D ^ T = B ^ T * A
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
        auto batch_size = A.size(0);
        auto m = A.size(1);
        auto k = A.size(2);
        auto n = B.size(2);

        auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
        auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

        auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
            workspace_buffer.data_ptr(),
            workspace_buffer.numel(),
            static_cast<b_type*>(B.data_ptr()),
            static_cast<a_type*>(A.data_ptr()),
            static_cast<d_type*>(D.data_ptr()),
            batch_size,
            n,
            m,
            k,
            static_cast<float*>(B_scale.data_ptr()),
            static_cast<float*>(A_scale.data_ptr()),
            lt_handle,
            stream);
        TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS,
                    "bmm_fp8_internal_cublaslt failed: ", cublasGetStatusString(status));
        return true;
      });
    });
  });
}
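Two notes for readers. First, the math performed here is a batched matmul with per-tensor dequantization scales; a simplified PyTorch reference of that semantics is sketched below (the function name is ours, and it ignores the cuBLASLt workspace and handle plumbing). Second, on the layout comment above: PyTorch buffers are row major while cuBLASLt assumes column major, so the same buffer reinterpreted column major is its transpose; computing D^T = B^T · A^T in column-major terms therefore yields row-major D, which is why B's pointer is passed before A's and the m/n dimensions are swapped in the call.

import torch

def bmm_fp8_ref(A, B, A_scale, B_scale, out_dtype=torch.bfloat16):
    # A: (b, m, k) fp8, B: (b, k, n) fp8, A_scale/B_scale: scalar float tensors.
    # Dequantize, multiply in float32, then cast to the fp16/bf16 output dtype.
    return ((A.float() * A_scale) @ (B.float() * B_scale)).to(out_dtype)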
sgl-kernel/include/sgl_kernel_ops.h
...
@@ -140,6 +140,15 @@ void cublas_grouped_gemm(
     const torch::Dtype& out_dtype,
     int64_t cublas_handle,
     int64_t cuda_stream);
 
+void bmm_fp8(
+    at::Tensor A,
+    at::Tensor B,
+    at::Tensor D,
+    at::Tensor A_scale,
+    at::Tensor B_scale,
+    at::Tensor workspace_buffer,
+    int64_t cublas_handle,
+    int64_t cuda_stream);
 /*
  * From csrc/moe
...
@@ -198,15 +207,6 @@ void build_tree_kernel(
 /*
  * From FlashInfer
  */
-void bmm_fp8(
-    at::Tensor A,
-    at::Tensor B,
-    at::Tensor D,
-    at::Tensor A_scale,
-    at::Tensor B_scale,
-    at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
 void min_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
...
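The relocated declaration takes the cuBLAS handle and CUDA stream as plain int64_t values rather than framework objects. Below is a hedged sketch of how a Python caller could supply them; the op name torch.ops.sgl_kernel.bmm_fp8 is our assumption for illustration (the binding itself is not part of this diff), while the two accessors used are standard PyTorch APIs.

import torch

def call_bmm_fp8(A, B, D, A_scale, B_scale, workspace):
    # Handle and stream are passed as raw integers, matching the int64_t
    # parameters in the declaration above.
    cublas_handle = torch.cuda.current_blas_handle()
    cuda_stream = torch.cuda.current_stream().cuda_stream
    torch.ops.sgl_kernel.bmm_fp8(  # assumed registration name
        A, B, D, A_scale, B_scale, workspace, cublas_handle, cuda_stream)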
sgl-kernel/pyproject.toml
 [build-system]
-requires = ["setuptools>=61.0", "wheel", "torch"]
+requires = [
+    "setuptools>=61.0",
+    "scikit-build-core>=0.10",
+    "torch==2.5.1",
+    "wheel",
+]
 build-backend = "setuptools.build_meta"
 
 [project]
...
sgl-kernel/setup.py
...
@@ -97,6 +97,8 @@ sources = [
     "csrc/allreduce/trt_reduce_kernel.cu",
     "csrc/attention/lightning_attention_decode_kernel.cu",
     "csrc/elementwise/fused_add_rms_norm_kernel.cu",
+    "csrc/elementwise/rope.cu",
+    "csrc/gemm/bmm_fp8.cu",
     "csrc/gemm/cublas_grouped_gemm.cu",
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
...
@@ -109,11 +111,9 @@ sources = [
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
     "3rdparty/flashinfer/csrc/activation.cu",
-    "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
-    "3rdparty/flashinfer/csrc/sampling.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
-    "3rdparty/flashinfer/csrc/rope.cu",
+    "3rdparty/flashinfer/csrc/sampling.cu",
 ]
 
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
...
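With the two new sources compiled into the extension, a minimal smoke test for the relocated RoPE op might look like the sketch below. It assumes the op is registered under torch.ops.sgl_kernel with the C++ signature shown earlier; that registration point and the cache construction are illustrative assumptions, not part of this diff.

import torch

def rope_smoke_test(nnz=4, heads=8, head_dim=128, max_seq_len=64):
    dev = "cuda"
    q = torch.randn(nnz, heads, head_dim, dtype=torch.float16, device=dev)
    k = torch.randn(nnz, heads, head_dim, dtype=torch.float16, device=dev)
    q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)
    # Build a (max_seq_len, head_dim) cache: cos in the first half, sin in the second.
    pos = torch.arange(max_seq_len, device=dev, dtype=torch.float32).unsqueeze(1)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=dev,
                                               dtype=torch.float32) / head_dim))
    angles = pos * inv_freq  # (max_seq_len, head_dim // 2)
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=1)
    pos_ids = torch.randint(0, max_seq_len, (nnz,), dtype=torch.int32, device=dev)
    stream = torch.cuda.current_stream().cuda_stream
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(  # assumed name
        q, k, q_rope, k_rope, cos_sin_cache, pos_ids, False, stream)
    assert q_rope.shape == q.shape and k_rope.shape == k.shape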