Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
131fbe9b
Commit
131fbe9b
authored
Feb 15, 2025
by
王敏
Browse files
[fix]修复moe模型专家并行kernel报错
parent
f112086f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
125 additions
and
70 deletions
+125
-70
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+125
-70
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
131fbe9b
...
...
@@ -263,6 +263,79 @@ __global__ void sgl_moe_align_block_size_kernel(
}
}
// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
/**
 * Groups the flattened `topk_ids` entries by expert for the experts in
 * [start_expert, end_expert) and pads every expert's segment up to a
 * multiple of `block_size`.
 *
 * Entries of `topk_ids` outside [start_expert, end_expert) are ignored. Kept
 * entries are re-based to the local index `expert_id - start_expert`; all
 * counters, offsets and the values stored in `expert_ids` use that local
 * index.
 *
 * @param topk_ids              Input expert ids, `numel` entries.
 * @param sorted_token_ids      Output: flat indices into `topk_ids`, each
 *                              written into the padded segment of its local
 *                              expert. Padding slots are not written here,
 *                              and the order inside a segment depends on the
 *                              order in which the atomics are serviced.
 * @param expert_ids            Output: local expert index for every
 *                              `block_size`-sized block of a non-empty
 *                              segment.
 * @param total_tokens_post_pad Output: receives `cumsum[num_experts]`.
 * @param num_experts           Number of counters to reset and accumulate.
 * @param block_size            Padding granularity of each segment.
 * @param numel                 Number of entries in `topk_ids`.
 * @param cumsum                Scratch buffer with `num_experts + 1` entries
 *                              holding the padded exclusive prefix sums.
 * @param start_expert          First expert id handled (inclusive).
 * @param end_expert            Last expert id handled (exclusive).
 *
 * Launch constraints that follow from the code below:
 *   - Work is partitioned with `threadIdx.x` / `blockDim.x` only, so the
 *     kernel is written for a single thread block.
 *   - The static shared arrays hold 32 * 8 = 256 counters, so local expert
 *     indices (and `num_experts`) must stay below that limit.
 *   - Counter rows are cleared per group of 32 threads, so the block needs
 *     at least CEILDIV(num_experts, 8) such groups.
 */
template <typename scalar_t>
__global__ void sgl_ep_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum, int32_t start_expert,
    int32_t end_expert) {
  __shared__ int32_t shared_counts[32][8];
  __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / 32;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  // Each group of 32 threads clears the row of counters it owns.
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  // The counting loop below lets every thread increment ANY row, not just
  // the one its own group cleared. Without this barrier a thread could add
  // to a counter that has not been zeroed yet (or have its increment wiped
  // by a late zero store), corrupting the per-expert counts.
  __syncthreads();

  // Every thread handles one contiguous slice of topk_ids.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  // Pass 1: count how many kept entries map to each local expert.
  for (size_t i = start_idx; i < numel && i < start_idx + tokens_per_thread;
       ++i) {
    int expert_id = topk_ids[i];
    if (expert_id >= start_expert && expert_id < end_expert) {
      expert_id -= start_expert;
      int warp_idx = expert_id / experts_per_warp;
      int expert_offset = expert_id % experts_per_warp;
      atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
    }
  }

  __syncthreads();

  // Thread 0 builds the prefix sum of the counts, each rounded up to a
  // multiple of block_size, and publishes the padded total.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      int expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // One thread per local expert: label each block of its segment and seed
  // the running write offset used by pass 2.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
  }

  __syncthreads();

  // Pass 2: scatter each kept entry's flat index into its expert's segment.
  for (size_t i = start_idx; i < numel && i < start_idx + tokens_per_thread;
       ++i) {
    int32_t expert_id = topk_ids[i];
    if (expert_id >= start_expert && expert_id < end_expert) {
      expert_id -= start_expert;
      int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
      sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
    }
  }
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
...
...
@@ -488,75 +561,55 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
((
num_thread
+
1
)
*
num_experts
)
*
sizeof
(
uint16_t
)
+
(
num_experts
+
1
)
*
sizeof
(
int32_t
);
// bool use_global_memory = false;
// bool use_i16 = false; // Use uint16_t for shared memory token counts
// if (shared_mem_i32 < device_max_shared_mem) {
// // Do nothing in this case. We're all set to use int32_t token counts
// } else if (shared_mem_i16 < device_max_shared_mem &&
// topk_ids.numel() <= 65535) {
// // when nelements of topk_ids is smaller than 65535 (max value of uint16),
// // element value of token_cnts would also smaller than 65535,
// // so we can use uint16 as dtype of token_cnts
// use_i16 = true;
// } else {
// use_global_memory = true;
// }
// if (use_global_memory) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// // tensors
// const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
// auto options_int = torch::TensorOptions()
// .dtype(torch::kInt)
// .device(topk_ids.device());
bool
use_sgl_kernel
=
false
;
bool
use_i16
=
false
;
// Use uint16_t for shared memory token counts
if
(
shared_mem_i32
<
device_max_shared_mem
)
{
// Do nothing in this case. We're all set to use int32_t token counts
}
else
if
(
shared_mem_i16
<
device_max_shared_mem
&&
topk_ids
.
numel
()
<=
65535
)
{
// When topk_ids has at most 65535 elements (the max value of uint16_t),
// no entry of token_cnts can exceed 65535 either,
// so uint16_t can be used as the dtype of token_cnts.
use_i16
=
true
;
}
else
{
use_sgl_kernel
=
true
;
}
if
(
use_sgl_kernel
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"sgl_ep_moe_align_block_size_kernel"
,
[
&
]
{
// Allocate the `cumsum` scratch buffer (num_experts + 1 int32 entries) on
// the same device as topk_ids; this launch uses no dynamic shared memory.
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
// torch::Tensor cumsum_buffer =
// torch::empty({num_experts + 1}, options_int);
// auto kernel =
// vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
// kernel<<<1, num_thread, 0, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
// cumsum_buffer.data_ptr<int32_t>());
// });
// } else if (use_i16) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// // set dynamic shared mem
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i16));
// kernel<<<1, num_thread, shared_mem_i16, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// } else {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i32));
// kernel<<<1, num_thread, shared_mem_i32, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// }
torch
::
Tensor
cumsum_buffer
=
torch
::
empty
({
num_experts
+
1
},
options_int
);
auto
kernel
=
vllm
::
moe
::
sgl_ep_moe_align_block_size_kernel
<
scalar_t
>
;
kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
start_expert
,
end_expert
);
});
}
else
if
(
use_i16
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"ep_moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
vllm
::
moe
::
ep_moe_align_block_size_kernel
<
scalar_t
,
uint16_t
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem_i16
));
kernel
<<<
1
,
num_thread
,
shared_mem_i16
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
start_expert
,
end_expert
);
});
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"ep_moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
...
...
@@ -570,6 +623,8 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
start_expert
,
end_expert
);
});
}
}
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment