Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d183246
Commit
7d183246
authored
May 15, 2025
by
wujl5
Browse files
optimized sum kernel
parent
2f6f5bb3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
4 deletions
+39
-4
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+39
-4
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
7d183246
...
...
@@ -363,6 +363,35 @@ __global__ void moe_sum_kernel(
}
}
// Sums the TOPK rows that belong to each token, column by column:
//
//   out[t * d + i] = input[(t * TOPK + 0) * d + i]
//                  + input[(t * TOPK + 1) * d + i]
//                  + ...
//                  + input[(t * TOPK + TOPK - 1) * d + i]
//
// The additions are performed in scalar_t, left to right (row 0 first).
//
// Launch layout (1-D grid, 1-D block):
//   * blockIdx.x / SPLIT_D selects the token, blockIdx.x % SPLIT_D selects a
//     column slice of width ceil(d / SPLIT_D); the grid therefore needs
//     SPLIT_D blocks per token.
//   * blockDim.x must not exceed BLOCK_DIM (the staging buffer has BLOCK_DIM
//     columns and is indexed by threadIdx.x).
//   * Static shared memory: TOPK * BLOCK_DIM * sizeof(scalar_t) bytes.
//
// Parameters:
//   out    destination, indexed as out[token * d + column].
//   input  source, indexed as input[(token * TOPK + k) * d + column].
//   d      number of columns per row.
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(scalar_t* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const int d) {
  static_assert(TOPK >= 1, "TOPK must be at least 1");
  static_assert(SPLIT_D >= 1, "SPLIT_D must be at least 1");
  static_assert(BLOCK_DIM >= 1, "BLOCK_DIM must be at least 1");

  const int token_idx = blockIdx.x / SPLIT_D;
  const int sub_block = blockIdx.x % SPLIT_D;

  // Column slice [d_start, d_end) owned by this block. The last slices may be
  // short or empty when d is not a multiple of SPLIT_D.
  const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
  const int64_t d_start = static_cast<int64_t>(sub_block) * d_per_block;
  const int64_t d_limit = d_start + d_per_block;
  const int64_t d_end = d_limit < d ? d_limit : static_cast<int64_t>(d);

  // Widen before multiplying: token_idx * TOPK * d can exceed INT_MAX.
  const int64_t token_offset = static_cast<int64_t>(token_idx) * TOPK * d;
  const int64_t out_offset = static_cast<int64_t>(token_idx) * d;

  __shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];

  for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
    // Stage the TOPK values of this column. Each thread only ever touches
    // column threadIdx.x of the staging buffer, so no data crosses threads
    // and no block barrier is required. A __syncthreads() here would be
    // incorrect: the trip count of this loop differs between threads whenever
    // (d_end - d_start) is not a multiple of blockDim.x, so not every thread
    // of the block would reach it.
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      sem_input[k][threadIdx.x] =
          input[token_offset + static_cast<int64_t>(k) * d + idx];
    }

    // Left-to-right accumulation in scalar_t, starting from row 0.
    scalar_t x = sem_input[0][threadIdx.x];
#pragma unroll
    for (int k = 1; k < TOPK; ++k) {
      x = x + sem_input[k][threadIdx.x];
    }

    out[out_offset + idx] = x;
  }
}
}
// namespace moe
}
// namespace vllm
...
...
@@ -504,6 +533,12 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
constexpr
int
splitD_
=
8
;
const
int
TOPK8_GRID_DIM
=
num_tokens
*
splitD_
;
constexpr
int
TOPK8_BLOCK_DIM
=
256
;
dim3
grid_8
(
TOPK8_GRID_DIM
);
dim3
block_8
(
TOPK8_BLOCK_DIM
);
switch
(
topk
)
{
case
2
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
...
...
@@ -530,15 +565,15 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
break
;
case
8
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_
kernel
"
,
[
&
]
{
vllm
::
moe
::
moe_sum_
kernel
<
scalar_t
,
8
><<<
grid
,
block
,
0
,
stream
>>>
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_
sharedmem_topk8
"
,
[
&
]{
vllm
::
moe
::
moe_sum_
sharedmem_topk8
<
scalar_t
,
8
,
splitD_
,
TOPK8_BLOCK_DIM
><<<
grid
_8
,
block
_8
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
default:
at
::
sum_out
(
output
,
input
,
1
);
break
;
}
}
\ No newline at end of file
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment