sglang · Commits · 47eb139f

Commit 47eb139f (Unverified), authored Dec 01, 2024 by Yineng Zhang, committed by GitHub Dec 01, 2024

feat: use warp reduce as a simple example (#2304)

parent 5c18a037

Showing 7 changed files with 193 additions and 17 deletions (+193, -17):
.gitignore                                             +33  -0
sgl-kernel/pyproject.toml                              +14  -17
sgl-kernel/setup.py                                    +20  -0
sgl-kernel/src/sgl-kernel/__init__.py                  +3   -0
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc          +21  -0
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu   +97  -0
sgl-kernel/src/sgl-kernel/ops/__init__.py              +5   -0
.gitignore · View file @ 47eb139f

@@ -185,3 +185,36 @@ work_dirs/
*.csv
!logo.png

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app
sgl-kernel/pyproject.toml · View file @ 47eb139f
[build-system]
-requires = ["setuptools>=61.0", "wheel"]
+requires = ["setuptools>=61.0", "wheel", "torch"]
build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
-version = "0.0.1"
+version = "0.0.2"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
license = { file = "LICENSE" }
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
]
-dependencies = ["torch"]
+dependencies = ["numpy"]

[project.optional-dependencies]
srt = ["torch"]
all = ["sgl-kernel[srt]"]

[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"

[tool.setuptools.packages.find]
exclude = [
    "dist*",
    "tests*",
]

[tool.setuptools]
package-dir = { "sgl_kernel" = "src/sgl-kernel" }
packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"]

[tool.wheel]
exclude = [
    "dist*",
    "tests*",
]
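With this change the runtime dependency list carries only numpy, while torch moves into the build requirements and the optional "srt" extra. A minimal sketch of checking the installed metadata after a build from this commit; the check itself is illustrative and not part of the change:

from importlib import metadata

# Illustrative only: assumes sgl-kernel from this commit has been installed.
print(metadata.version("sgl-kernel"))    # expected "0.0.2"
print(metadata.requires("sgl-kernel"))   # runtime dep: numpy; torch only via the "srt" extra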
sgl-kernel/setup.py (new file, 0 → 100644) · View file @ 47eb139f
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="sgl-kernel",
    version="0.0.2",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    ext_modules=[
        CUDAExtension(
            "sgl_kernel.ops.warp_reduce_cuda",
            [
                "src/sgl-kernel/csrc/warp_reduce.cc",
                "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
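For quick iteration on the kernel sources without rebuilding the packaged extension, the same two files can also be JIT-compiled with torch's cpp_extension loader. This is only a sketch of an alternative workflow, not something the commit adds; the paths assume the command is run from the sgl-kernel/ directory with a CUDA toolchain available:

from torch.utils.cpp_extension import load

# JIT-compile the extension in-process; paths are relative to sgl-kernel/.
warp_reduce_cuda = load(
    name="warp_reduce_cuda",
    sources=[
        "src/sgl-kernel/csrc/warp_reduce.cc",
        "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
    ],
    verbose=True,
)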
sgl-kernel/src/sgl-kernel/__init__.py · View file @ 47eb139f
from .ops import warp_reduce

__all__ = ["warp_reduce"]
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc (new file, 0 → 100644) · View file @ 47eb139f
#include <torch/extension.h>

#include <vector>

torch::Tensor warp_reduce_cuda(torch::Tensor input);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_INPUT(input);
  return warp_reduce_cuda(input);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
}
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu (new file, 0 → 100644) · View file @ 47eb139f
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#define FINAL_MASK 0xffffffff
#define BLOCK_SIZE 256

template <typename scalar_t>
__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) {
  return a + b;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(FINAL_MASK, val, offset);
  }
  return val;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
  __shared__ scalar_t shared[32];
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;

  val = warpReduceSum(val);  // First reduce within warp

  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

  __syncthreads();  // Wait for all partial reductions

  // Read from shared memory only if that warp existed
  val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;

  if (wid == 0) val = warpReduceSum(val);  // Final reduce within first warp

  return val;
}

template <typename scalar_t>
__global__ void warp_reduce_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
    int N) {
  scalar_t sum = 0;

  // Grid-stride loop
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    sum += input[i];
  }

  // Perform block-wide reduction
  sum = blockReduceSum(sum);

  // Write result for this block to global memory
  if (threadIdx.x == 0) {
    output[blockIdx.x] = sum;
  }
}

torch::Tensor warp_reduce_cuda(torch::Tensor input) {
  // Input validation
  TORCH_CHECK(input.dim() == 1, "1D tensor expected");
  TORCH_CHECK(input.is_cuda(), "CUDA tensor expected");

  const auto N = input.size(0);

  // Handle empty tensor
  if (N == 0) {
    return torch::zeros({1}, input.options());
  }

  // Calculate grid dimensions
  const int threads = BLOCK_SIZE;
  const int blocks = (N + threads - 1) / threads;

  // Allocate output tensor for partial sums
  auto output = torch::empty({blocks}, input.options());

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
    warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
        input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        N);
  }));

  // Sum the partial results
  return output.sum();
}
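The launch produces one partial sum per 256-thread block: each thread accumulates its grid-stride elements, warps reduce with __shfl_down_sync, the first warp combines the per-warp values, and the host finishes with output.sum(). A pure-PyTorch sketch of the same arithmetic, useful as a reference when testing; the function name is illustrative and not part of this commit:

import torch

BLOCK_SIZE = 256  # mirrors the #define in warp_reduce_kernel.cu

def warp_reduce_reference(x: torch.Tensor) -> torch.Tensor:
    # Per-block partial sums followed by a final sum, matching the kernel's launch shape.
    n = x.numel()
    if n == 0:
        return torch.zeros(1, dtype=x.dtype, device=x.device)  # kernel's empty-input path
    blocks = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
    padded = torch.zeros(blocks * BLOCK_SIZE, dtype=x.dtype, device=x.device)
    padded[:n] = x
    partial = padded.view(blocks, BLOCK_SIZE).sum(dim=1)  # one partial sum per block
    return partial.sum()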
sgl-kernel/src/sgl-kernel/ops/__init__.py (new file, 0 → 100644) · View file @ 47eb139f
from .warp_reduce_cuda import reduce as _reduce


def warp_reduce(input_tensor):
    return _reduce(input_tensor)