change / sglang · Commit 31548116 (Unverified)
Authored Dec 26, 2024 by Yineng Zhang; committed by GitHub, Dec 26, 2024
fix moe_align_block_size_kernel for shared memory issue (#2579)
Co-authored-by: ispobock <ispobaoke@163.com>
Parent: 53aed988
Showing 6 changed files with 225 additions and 2 deletions:
    sgl-kernel/pyproject.toml                            +1   -1
    sgl-kernel/setup.py                                  +20  -0
    sgl-kernel/src/sgl-kernel/__init__.py                +8   -1
    sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu   +151 -0
    sgl-kernel/src/sgl-kernel/ops/__init__.py            +19  -0
    sgl-kernel/tests/test_moe_align.py                   +26  -0
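For context on the "shared memory issue" in the title: with large expert counts, the kernel's per-block tokens_cnts and cumsum counters no longer fit in the default per-block shared memory, which is presumably why the launcher in moe_align_kernel.cu below allocates them in global memory with cudaMalloc. A rough size estimate (my illustration, not part of the commit), using the same formulas as that launcher:

# Hypothetical back-of-the-envelope sizing (illustration only), mirroring the
# mem_tokens_cnts / mem_cumsum formulas in moe_align_kernel.cu below.
num_experts = 256                                        # e.g. a large MoE router
mem_tokens_cnts = ((num_experts + 1) * num_experts) * 4  # int32 counters, bytes
mem_cumsum = (num_experts + 1) * 4                       # int32 prefix sums, bytes

print((mem_tokens_cnts + mem_cumsum) / 1024)             # ~258 KiB, far beyond the
                                                         # 48 KiB default shared memory
                                                         # available to a thread block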
sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post5"
+version = "0.0.2.post6"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
sgl-kernel/setup.py
@@ -109,6 +109,26 @@ setup(
             libraries=["c10", "torch", "torch_python"],
             extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
         ),
+        CUDAExtension(
+            "sgl_kernel.ops.moe_align_block_size",
+            [
+                "src/sgl-kernel/csrc/moe_align_kernel.cu",
+            ],
+            extra_compile_args={
+                "nvcc": [
+                    "-O3",
+                    "-Xcompiler",
+                    "-fPIC",
+                    "-gencode=arch=compute_75,code=sm_75",
+                    "-gencode=arch=compute_80,code=sm_80",
+                    "-gencode=arch=compute_89,code=sm_89",
+                    "-gencode=arch=compute_90,code=sm_90",
+                ],
+                "cxx": ["-O3"],
+            },
+            libraries=["c10", "torch", "torch_python"],
+            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
     install_requires=["torch"],
sgl-kernel/src/sgl-kernel/__init__.py
-from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce
+from .ops import (
+    custom_dispose,
+    custom_reduce,
+    init_custom_reduce,
+    moe_align_block_size,
+    warp_reduce,
+)
 __all__ = [
     "warp_reduce",
     "init_custom_reduce",
     "custom_dispose",
     "custom_reduce",
+    "moe_align_block_size",
 ]
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
new file mode 100644

// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/extension.h>

#include <THC/THCAtomics.cuh>

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

#define CEILDIV(x, y) (((x) + (y)-1) / (y))

#define DISPATCH_CASE_INTEGRAL_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                            int32_t* expert_ids, int32_t* total_tokens_post_pad,
                                            int32_t num_experts, int32_t block_size, size_t numel,
                                            int32_t* tokens_cnts, int32_t* cumsum) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] =
          cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
    // tensors
    const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
    const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t);
    const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);

    // allocate global memory
    int32_t* tokens_cnts;
    int32_t* cumsum;
    cudaMalloc(&tokens_cnts, mem_tokens_cnts);
    cudaMalloc(&cumsum, mem_cumsum);

    // set dynamic shared mem
    auto kernel = moe_align_block_size_kernel<scalar_t>;
    kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                         experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                                         num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum);
    cudaFree(tokens_cnts);
    cudaFree(cumsum);
  });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
}
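To make the layout described in the kernel's comment concrete, here is a minimal pure-Python reference (my sketch, not part of the commit) that mirrors the count / cumsum / scatter steps above and reproduces the worked example topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4:

# Hypothetical pure-Python reference for the alignment logic above (illustration only).
def moe_align_block_size_ref(topk_ids, num_experts, block_size):
    numel = len(topk_ids)
    pad = numel  # the Python wrapper presets sorted_ids to topk_ids.numel()

    # 1) count tokens per expert
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1

    # 2) cumulative offsets, with each expert's span padded up to a multiple of block_size
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        cumsum[e + 1] = cumsum[e] + -(-counts[e] // block_size) * block_size
    total_tokens_post_pad = cumsum[num_experts]

    # 3) expert id owning each block, and token indices scattered into expert order
    expert_ids = [e for e in range(num_experts)
                  for _ in range(cumsum[e], cumsum[e + 1], block_size)]
    sorted_token_ids = [pad] * total_tokens_post_pad
    offset = list(cumsum)
    for i, e in enumerate(topk_ids):
        sorted_token_ids[offset[e]] = i
        offset[e] += 1
    return sorted_token_ids, expert_ids, total_tokens_post_pad


# The '*' padding slots from the comment show up here as the value 9 (= numel).
print(moe_align_block_size_ref([0, 1, 2, 1, 2, 3, 0, 3, 4], num_experts=5, block_size=4))
# ([0, 6, 9, 9, 1, 3, 9, 9, 2, 4, 9, 9, 5, 7, 9, 9, 8, 9, 9, 9], [0, 1, 2, 3, 4], 20)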
sgl-kernel/src/sgl-kernel/ops/__init__.py
 from .custom_reduce_cuda import all_reduce as _all_reduce
 from .custom_reduce_cuda import dispose as _dispose
 from .custom_reduce_cuda import init_custom_ar as _init_custom_ar
+from .moe_align_block_size import moe_align_block_size as _moe_align_block_size
 from .warp_reduce_cuda import reduce as _reduce
@@ -18,3 +19,21 @@ def custom_dispose(fa):
 def custom_reduce(fa, inp, out):
     _all_reduce(fa, inp, out)
+
+
+def moe_align_block_size(
+    topk_ids,
+    num_experts,
+    block_size,
+    sorted_token_ids,
+    experts_ids,
+    num_tokens_post_pad,
+):
+    _moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_token_ids,
+        experts_ids,
+        num_tokens_post_pad,
+    )
sgl-kernel/tests/test_moe_align.py
new file mode 100644

import torch
from sgl_kernel import moe_align_block_size


def test_moe_align_block_size():
    num_experts = 256
    block_size = 128
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)


test_moe_align_block_size()
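The test above only exercises the call path; as a hedged follow-up sketch (not part of the commit), one could add assertions along these lines, assuming the same buffer conventions as in test_moe_align_block_size (padding entries preset to topk_ids.numel()):

# Hypothetical extra checks (illustration only), reusing the tensors built in the test above.
def check_alignment(topk_ids, block_size, sorted_ids, expert_ids, num_tokens_post_pad):
    total = num_tokens_post_pad.item()
    assert total % block_size == 0                 # padded length is block-aligned
    pad = topk_ids.numel()
    placed = sorted_ids[:total]
    real_positions = (placed < pad).nonzero(as_tuple=True)[0]
    assert real_positions.numel() == pad           # every token appears exactly once
    for pos in real_positions.tolist():
        token = placed[pos].item()
        # each placed token must route to the expert that owns its block
        assert topk_ids.view(-1)[token].item() == expert_ids[pos // block_size].item()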