Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
863176e5
Commit
863176e5
authored
Sep 06, 2025
by
SAC_fanth
Browse files
增加reduce修改
parent
a5d54d38
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
107 additions
and
13 deletions
+107
-13
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+90
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+1
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-1
vllm/version.py
vllm/version.py
+12
-11
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
863176e5
...
...
@@ -173,6 +173,35 @@ __global__ void moe_sum_kernel(
}
}
template
<
typename
scalar_t
,
int
TOPK
,
int
SPLIT_D
,
int
BLOCK_DIM
>
__global__
void
moe_sum_sharedmem_topk8
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
)
{
const
int
token_idx
=
blockIdx
.
x
/
SPLIT_D
;
const
int
sub_block
=
blockIdx
.
x
%
SPLIT_D
;
const
int
d_per_block
=
(
d
+
SPLIT_D
-
1
)
/
SPLIT_D
;
const
int64_t
d_start
=
sub_block
*
d_per_block
;
const
int64_t
token_offset
=
token_idx
*
TOPK
*
d
;
const
int64_t
d_end
=
min
(
d_start
+
d_per_block
,
d
);
__shared__
__align__
(
16
)
scalar_t
sem_input
[
TOPK
][
BLOCK_DIM
];
for
(
int64_t
idx
=
d_start
+
threadIdx
.
x
;
idx
<
d_end
;
idx
+=
blockDim
.
x
)
{
sem_input
[
0
][
threadIdx
.
x
]
=
input
[
token_offset
+
0
*
d
+
idx
];
sem_input
[
1
][
threadIdx
.
x
]
=
input
[
token_offset
+
1
*
d
+
idx
];
sem_input
[
2
][
threadIdx
.
x
]
=
input
[
token_offset
+
2
*
d
+
idx
];
sem_input
[
3
][
threadIdx
.
x
]
=
input
[
token_offset
+
3
*
d
+
idx
];
sem_input
[
4
][
threadIdx
.
x
]
=
input
[
token_offset
+
4
*
d
+
idx
];
sem_input
[
5
][
threadIdx
.
x
]
=
input
[
token_offset
+
5
*
d
+
idx
];
sem_input
[
6
][
threadIdx
.
x
]
=
input
[
token_offset
+
6
*
d
+
idx
];
sem_input
[
7
][
threadIdx
.
x
]
=
input
[
token_offset
+
7
*
d
+
idx
];
__syncthreads
();
scalar_t
x
=
sem_input
[
0
][
threadIdx
.
x
]
+
sem_input
[
1
][
threadIdx
.
x
]
+
sem_input
[
2
][
threadIdx
.
x
]
+
sem_input
[
3
][
threadIdx
.
x
]
+
sem_input
[
4
][
threadIdx
.
x
]
+
sem_input
[
5
][
threadIdx
.
x
]
+
sem_input
[
6
][
threadIdx
.
x
]
+
sem_input
[
7
][
threadIdx
.
x
];
out
[
token_idx
*
d
+
idx
]
=
x
;
}
}
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_small_batch_expert_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
...
...
@@ -353,6 +382,67 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
});
break
;
default:
at
::
sum_out
(
output
,
input
,
1
);
break
;
}
}
void
moe_sum_opt1
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
torch
::
Tensor
&
output
)
// [num_tokens, hidden_size]
{
const
int
hidden_size
=
input
.
size
(
-
1
);
const
auto
num_tokens
=
output
.
numel
()
/
hidden_size
;
const
int
topk
=
input
.
size
(
1
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
constexpr
int
splitD_
=
8
;
const
int
TOPK8_GRID_DIM
=
num_tokens
*
splitD_
;
constexpr
int
TOPK8_BLOCK_DIM
=
256
;
dim3
grid_8
(
TOPK8_GRID_DIM
);
dim3
block_8
(
TOPK8_BLOCK_DIM
);
switch
(
topk
)
{
case
2
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
2
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
3
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
3
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
4
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_kernel"
,
[
&
]
{
vllm
::
moe
::
moe_sum_kernel
<
scalar_t
,
4
><<<
grid
,
block
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
case
8
:
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_sum_sharedmem_topk8"
,
[
&
]{
vllm
::
moe
::
moe_sum_sharedmem_topk8
<
scalar_t
,
8
,
splitD_
,
TOPK8_BLOCK_DIM
><<<
grid_8
,
block_8
,
0
,
stream
>>>
(
output
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
hidden_size
);
});
break
;
default:
at
::
sum_out
(
output
,
input
,
1
);
break
;
...
...
csrc/moe/moe_ops.h
View file @
863176e5
...
...
@@ -7,6 +7,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch
::
Tensor
&
gating_output
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
void
moe_sum_opt1
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
...
...
csrc/moe/torch_bindings.cpp
View file @
863176e5
...
...
@@ -11,8 +11,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor input, Tensor! output) -> ()"
);
m
.
def
(
"moe_sum_opt1(Tensor input, Tensor! output) -> ()"
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
m
.
impl
(
"moe_sum_opt1"
,
torch
::
kCUDA
,
&
moe_sum_opt1
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m
.
def
(
...
...
vllm/_custom_ops.py
View file @
863176e5
...
...
@@ -1971,7 +1971,8 @@ def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
# moe
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
def
moe_sum_opt1
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
_moe_C
.
moe_sum_opt1
(
input
,
output
)
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
...
vllm/version.py
View file @
863176e5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.9.2"
__version_tuple__
=
(
0
,
9
,
2
)
__hcu_version__
=
f
'0.9.2+das.opt1.rc1.a5d54d3.dtk25041'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)
"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
...
...
@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
"""
For the purpose of testing, return a previous minor version number.
"""
'''
For the purpose of testing, return a previous minor version number.
'''
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment