Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea657f20
Unverified
Commit
ea657f20
authored
Dec 08, 2025
by
gnovack
Committed by
GitHub
Dec 09, 2025
Browse files
Lora MoE Align Improvements (#29257)
Signed-off-by:
gnovack
<
gnovack@amazon.com
>
parent
db14f61f
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
360 additions
and
249 deletions
+360
-249
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+354
-71
csrc/moe/moe_lora_align_sum_kernels.cu
csrc/moe/moe_lora_align_sum_kernels.cu
+0
-174
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+1
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-1
tests/lora/test_moe_lora_align_sum.py
tests/lora/test_moe_lora_align_sum.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-0
No files found.
CMakeLists.txt
View file @
ea657f20
...
@@ -944,7 +944,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
...
@@ -944,7 +944,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/topk_softmax_kernels.cu"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
ea657f20
This diff is collapsed.
Click to expand it.
csrc/moe/moe_lora_align_sum_kernels.cu
deleted
100644 → 0
View file @
db14f61f
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {

// Flattens a 2-D (row, column) coordinate into a linear offset for a
// row-major layout in which every row holds `row_stride` elements.
__device__ __forceinline__ int32_t index(int32_t row_stride, int32_t row_idx,
                                         int32_t col_idx) {
  const int32_t row_base = row_idx * row_stride;
  return row_base + col_idx;
}

}  // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
/**
 * Groups the flattened top-k entries by expert, separately for every active
 * LoRA adapter, and pads each expert's group to a multiple of `block_size`.
 *
 * One thread block handles one entry of `lora_ids` (selected by blockIdx.x).
 * The block exits immediately when that entry is -1 or when
 * `adapter_enabled[lora_id]` is 0, leaving its output rows untouched.
 *
 * Dynamic shared memory layout (in int32_t units):
 *   [0, num_experts]            cumsum: padded start offset of every expert
 *   [num_experts + 1, ...)      tokens_cnts: (blockDim.x + 1) x num_experts
 *                               table; row t + 1 holds the counts produced by
 *                               thread t, row 0 is the zero seed of the
 *                               prefix sum.
 *
 * @param topk_ids              expert id of every flattened top-k entry
 *                              (`numel` values); used unchecked as a column
 *                              index, so it must lie in [0, num_experts).
 * @param token_lora_mapping    LoRA id per token, indexed by `i / topk_num`.
 * @param block_size            padding granularity of each expert group.
 * @param num_experts           number of experts (columns of tokens_cnts).
 * @param max_loras             unused inside the kernel.
 * @param numel                 number of entries in `topk_ids`; also the
 *                              filler value written to unused output slots.
 * @param max_num_tokens_padded row length of `sorted_token_ids`.
 * @param max_num_m_blocks      row length of `expert_ids`.
 * @param sorted_token_ids      out: row `lora_id` receives the entry indices
 *                              grouped by expert, `numel` in padding slots.
 * @param expert_ids            out: row `lora_id` receives the expert id of
 *                              every block, -1 for unused blocks.
 * @param topk_num              number of top-k entries per token.
 * @param total_tokens_post_pad out: padded entry count per LoRA id.
 * @param adapter_enabled       per-LoRA-id flag; 0 disables the adapter.
 * @param lora_ids              LoRA id handled by each thread block, -1 for
 *                              an empty slot.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
    int32_t* lora_ids) {
  // Every thread owns one contiguous shard of the flattened top-k entries.
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_idx = blockIdx.x;
  int lora_id = lora_ids[lora_idx];
  // The condition depends only on blockIdx.x, so the whole block leaves
  // together and no thread is left waiting at a later __syncthreads().
  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
    return;
  }

  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  // Clear this thread's row of the count table.
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  // Count, per expert, the entries of this shard that belong to lora_id.
  for (size_t i = start_idx; i < numel && i < start_idx + tokens_per_thread;
       ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // Afterwards row t holds the number of matching entries seen by threads
  // 0..t-1, and row blockDim.x holds the per-expert total.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0, rounding every
  // expert's count up to a multiple of block_size.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  // Scatter the entry indices into their expert group.
  for (size_t i = start_idx; i < numel && i < start_idx + tokens_per_thread;
       ++i) {
    // Entries of other adapters own no slot in this row. Skipping them avoids
    // an atomic read-modify-write of zero on a slot that belongs to a
    // different entry, or that lies past the last padded slot when the
    // computed rank equals the row's padded total.
    if (token_lora_mapping[i / topk_num] != lora_id) {
      continue;
    }

    const int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    const int32_t cnt_idx = index(num_experts, threadIdx.x, expert_id);
    const int32_t rank_post_pad = tokens_cnts[cnt_idx] + cumsum[expert_id];

    // Each slot has exactly one writer, so a plain store is sufficient.
    sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)] =
        static_cast<int32_t>(i);
    tokens_cnts[cnt_idx] += 1;
  }
}
/**
 * Host entry point that launches moe_lora_align_sum_kernel with one thread
 * block per LoRA slot (`max_loras` blocks).
 *
 * The output tensors `sorted_token_ids`, `expert_ids` and
 * `num_tokens_post_pad` are written in place by the kernel; all of them, as
 * well as `token_lora_mapping`, `adapter_enabled` and `lora_ids`, are accessed
 * as int32 data. `topk_ids` is dispatched over the integral types and its
 * second dimension is taken as the top-k width.
 *
 * Raises (via TORCH_CHECK) when `block_size` is not positive, when more than
 * 1024 threads per block would be needed, or when the required dynamic shared
 * memory exceeds the device's opt-in limit.
 */
void moe_lora_align_block_size(
    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
    int64_t num_experts, int64_t block_size, int64_t max_loras,
    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids) {
  const int topk_num = topk_ids.size(1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  // Make the device that owns topk_ids current, so the stream lookup and the
  // kernel launch below target that device rather than whichever device
  // happens to be current in the calling thread.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));

  int device_max_shared_mem = 0;
  auto dev = topk_ids.get_device();
  // Surface a failed attribute query instead of comparing against an
  // uninitialized limit.
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The kernel needs at least one thread per expert.
  const int32_t num_thread =
      max(static_cast<int32_t>(num_experts), 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");

  // (num_thread + 1) x num_experts count table plus the num_experts + 1
  // cumulative offsets, all stored as int32_t.
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  TORCH_CHECK(shared_mem <= device_max_shared_mem,
              "Shared memory usage exceeds device limit, and global memory "
              "fallback is not implemented yet.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        // Opt in to the amount of dynamic shared memory computed above.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>(),
            adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
      });
}
\ No newline at end of file
csrc/moe/moe_ops.h
View file @
ea657f20
...
@@ -27,7 +27,7 @@ void moe_lora_align_block_size(
...
@@ -27,7 +27,7 @@ void moe_lora_align_block_size(
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
);
torch
::
Tensor
lora_ids
,
std
::
optional
<
torch
::
Tensor
>
maybe_expert_map
);
#ifndef USE_ROCM
#ifndef USE_ROCM
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
...
...
csrc/moe/torch_bindings.cpp
View file @
ea657f20
...
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor !experts_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "
);
" Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () "
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
#ifndef USE_ROCM
#ifndef USE_ROCM
...
...
tests/lora/test_moe_lora_align_sum.py
View file @
ea657f20
...
@@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
...
@@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
@
pytest
.
mark
.
parametrize
(
"topk_num"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"topk_num"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
64
,
128
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"max_loras"
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_loras"
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_moe_lora_align_block_size
(
def
test_moe_lora_align_block_size
(
...
...
vllm/_custom_ops.py
View file @
ea657f20
...
@@ -1961,6 +1961,7 @@ def moe_lora_align_block_size(
...
@@ -1961,6 +1961,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
topk_ids
,
topk_ids
,
...
@@ -1975,6 +1976,7 @@ def moe_lora_align_block_size(
...
@@ -1975,6 +1976,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
,
num_tokens_post_pad
,
adapter_enabled
,
adapter_enabled
,
lora_ids
,
lora_ids
,
expert_map
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment