Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea657f20
Unverified
Commit
ea657f20
authored
Dec 08, 2025
by
gnovack
Committed by
GitHub
Dec 09, 2025
Browse files
Lora MoE Align Improvements (#29257)
Signed-off-by:
gnovack
<
gnovack@amazon.com
>
parent
db14f61f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
360 additions
and
249 deletions
+360
-249
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+354
-71
csrc/moe/moe_lora_align_sum_kernels.cu
csrc/moe/moe_lora_align_sum_kernels.cu
+0
-174
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+1
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-1
tests/lora/test_moe_lora_align_sum.py
tests/lora/test_moe_lora_align_sum.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-0
No files found.
CMakeLists.txt
View file @
ea657f20
...
@@ -944,7 +944,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
...
@@ -944,7 +944,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/topk_softmax_kernels.cu"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
ea657f20
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
namespace
vllm
{
namespace
vllm
{
namespace
moe
{
namespace
moe
{
namespace
batched_moe_align_block_size
{
namespace
batched_moe_align_block_size
{
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
...
@@ -80,23 +79,30 @@ __global__ void batched_moe_align_block_size_kernel(
...
@@ -80,23 +79,30 @@ __global__ void batched_moe_align_block_size_kernel(
}
// namespace batched_moe_align_block_size
}
// namespace batched_moe_align_block_size
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__
global
__
void
moe_align_block_size
_kernel
(
__
device
__
void
_
moe_align_block_size
(
const
scalar_t
*
__restrict__
topk_ids
,
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
max_num_tokens_padded
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
max_num_tokens_padded
,
bool
has_expert_map
)
{
int32_t
max_num_m_blocks
,
int32_t
model_offset
,
int32_t
inactive_expert_id
,
int32_t
topk_num
,
int32_t
*
token_mask
,
bool
has_expert_map
)
{
extern
__shared__
int32_t
shared_counts
[];
extern
__shared__
int32_t
shared_counts
[];
// Use a separate threadblock to fill sorted_token_ids.
// Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int
sorted_token_ids_offset
=
max_num_tokens_padded
*
model_offset
;
int
expert_ids_offset
=
max_num_m_blocks
*
model_offset
;
int
cumsum_offset
=
(
num_experts
+
1
)
*
model_offset
;
// Use separate threadblocks to fill sorted_token_ids.
// This is safe since the current kernel does not use sorted_token_ids.
// This is safe since the current kernel does not use sorted_token_ids.
if
(
blockIdx
.
x
==
1
)
{
if
(
blockIdx
.
x
%
2
)
{
// Initialize sorted_token_ids with numel
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
it
]
=
numel
;
sorted_token_ids
[
sorted_token_ids_offset
+
it
]
=
numel
;
}
}
return
;
return
;
}
}
...
@@ -127,7 +133,9 @@ __global__ void moe_align_block_size_kernel(
...
@@ -127,7 +133,9 @@ __global__ void moe_align_block_size_kernel(
}
}
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
atomicAdd
(
&
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
],
1
);
int
mask
=
token_mask
==
nullptr
?
1
:
token_mask
[
i
/
topk_num
];
atomicAdd
(
&
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
],
mask
);
}
}
__syncthreads
();
__syncthreads
();
...
@@ -148,77 +156,44 @@ __global__ void moe_align_block_size_kernel(
...
@@ -148,77 +156,44 @@ __global__ void moe_align_block_size_kernel(
int
cumsum_val
;
int
cumsum_val
;
BlockScan
(
temp_storage
).
ExclusiveSum
(
expert_count
,
cumsum_val
);
BlockScan
(
temp_storage
).
ExclusiveSum
(
expert_count
,
cumsum_val
);
if
(
expert_id
<=
num_experts
)
{
if
(
expert_id
<=
num_experts
)
{
cumsum
[
expert_id
]
=
cumsum_val
;
cumsum
[
cumsum_offset
+
expert_id
]
=
cumsum_val
;
}
}
if
(
expert_id
==
num_experts
)
{
if
(
expert_id
==
num_experts
)
{
*
total_tokens_post_pad
=
cumsum_val
;
total_tokens_post_pad
[
model_offset
]
=
cumsum_val
;
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
<
num_experts
)
{
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
for
(
int
i
=
cumsum
[
cumsum_offset
+
threadIdx
.
x
];
i
+=
block_size
)
{
i
<
cumsum
[
cumsum_offset
+
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
expert_ids
[
expert_ids_offset
+
i
/
block_size
]
=
threadIdx
.
x
;
}
}
}
}
// Fill remaining expert_ids with 0
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
threadIdx
.
x
;
const
size_t
fill_start_idx
=
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
cumsum
[
cumsum_offset
+
num_experts
]
/
block_size
+
threadIdx
.
x
;
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
blockDim
.
x
)
{
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
blockDim
.
x
)
{
expert_ids
[
i
]
=
0
;
expert_ids
[
expert_ids_offset
+
i
]
=
inactive_expert_id
;
}
}
template
<
typename
scalar_t
>
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
int32_t
*
__restrict__
expert_map
,
size_t
numel
,
int32_t
num_experts
,
bool
has_expert_map
)
{
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
int32_t
expert_id
=
topk_ids
[
i
];
if
(
expert_id
>=
num_experts
)
{
continue
;
}
if
(
has_expert_map
)
{
expert_id
=
expert_map
[
expert_id
];
// filter invalid experts
if
(
expert_id
==
-
1
)
continue
;
}
int32_t
rank_post_pad
=
atomicAdd
(
&
cumsum_buffer
[
expert_id
],
1
);
sorted_token_ids
[
rank_post_pad
]
=
i
;
}
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., topk, d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
scalar_t
x
=
0.0
;
#pragma unroll
for
(
int
k
=
0
;
k
<
TOPK
;
++
k
)
{
x
+=
VLLM_LDG
(
&
input
[
token_idx
*
TOPK
*
d
+
k
*
d
+
idx
]);
}
out
[
token_idx
*
d
+
idx
]
=
x
;
}
}
}
}
template
<
typename
scalar_t
,
int32_t
fill_threads
>
template
<
typename
scalar_t
,
int32_t
fill_threads
>
__
global
__
void
moe_align_block_size_small_batch_expert
_kernel
(
__
device
__
void
_
moe_align_block_size_small_batch_expert
(
const
scalar_t
*
__restrict__
topk_ids
,
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
block_size
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
max_num_tokens_padded
,
bool
has_expert_map
)
{
size_t
numel
,
int32_t
max_num_tokens_padded
,
int32_t
max_num_m_blocks
,
int32_t
inactive_expert_id
,
int32_t
model_offset
,
int32_t
topk_num
,
int32_t
*
token_mask
,
bool
has_expert_map
)
{
// Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int
sorted_token_ids_offset
=
max_num_tokens_padded
*
model_offset
;
int
expert_ids_offset
=
max_num_m_blocks
*
model_offset
;
// Use an additional group of threads to fill sorted_token_ids.
// Use an additional group of threads to fill sorted_token_ids.
// Since the current kernel will use sorted_token_ids afterward,
// Since the current kernel will use sorted_token_ids afterward,
// we fill sorted_token_ids within the same threadblock to make
// we fill sorted_token_ids within the same threadblock to make
...
@@ -227,7 +202,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -227,7 +202,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
// Initialize sorted_token_ids with numel
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
fill_threads
)
{
it
+=
fill_threads
)
{
sorted_token_ids
[
it
]
=
numel
;
sorted_token_ids
[
sorted_token_ids_offset
+
it
]
=
numel
;
}
}
// Three __syncthreads() corresponding to the other threads
// Three __syncthreads() corresponding to the other threads
__syncthreads
();
__syncthreads
();
...
@@ -254,7 +229,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -254,7 +229,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
// filter invalid expert
// filter invalid expert
if
(
expert_id
==
-
1
)
continue
;
if
(
expert_id
==
-
1
)
continue
;
}
}
++
tokens_cnts
[(
tid
+
1
)
*
num_experts
+
expert_id
];
int
mask
=
token_mask
==
nullptr
?
1
:
token_mask
[
i
/
topk_num
];
tokens_cnts
[(
tid
+
1
)
*
num_experts
+
expert_id
]
+=
mask
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -277,22 +253,22 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -277,22 +253,22 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
CEILDIV
(
tokens_cnts
[
stride
*
num_experts
+
i
-
1
],
block_size
)
*
CEILDIV
(
tokens_cnts
[
stride
*
num_experts
+
i
-
1
],
block_size
)
*
block_size
;
block_size
;
}
}
*
total_tokens_post_pad
=
static_cast
<
int32_t
>
(
cumsum
[
num_experts
]);
total_tokens_post_pad
[
model_offset
]
=
static_cast
<
int32_t
>
(
cumsum
[
num_experts
]);
}
}
__syncthreads
();
__syncthreads
();
if
(
tid
<
num_experts
)
{
if
(
tid
<
num_experts
)
{
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
block_size
)
{
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
tid
;
expert_ids
[
expert_ids_offset
+
i
/
block_size
]
=
tid
;
}
}
}
}
// Fill remaining expert_ids with 0
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
tid
;
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
tid
;
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
stride
)
{
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
stride
)
{
expert_ids
[
expert_ids_offset
+
i
]
=
inactive_expert_id
;
expert_ids
[
i
]
=
0
;
}
}
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
...
@@ -304,9 +280,193 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -304,9 +280,193 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
int32_t
rank_post_pad
=
int32_t
rank_post_pad
=
tokens_cnts
[
tid
*
num_experts
+
expert_id
]
+
cumsum
[
expert_id
];
tokens_cnts
[
tid
*
num_experts
+
expert_id
]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
if
(
token_mask
==
nullptr
||
token_mask
[
i
/
topk_num
])
{
sorted_token_ids
[
sorted_token_ids_offset
+
rank_post_pad
]
=
i
;
++
tokens_cnts
[
tid
*
num_experts
+
expert_id
];
++
tokens_cnts
[
tid
*
num_experts
+
expert_id
];
}
}
}
}
template
<
typename
scalar_t
>
__device__
void
_count_and_sort_expert_tokens
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
int32_t
*
__restrict__
expert_map
,
size_t
numel
,
int32_t
num_experts
,
int32_t
max_num_tokens_padded
,
int32_t
*
__restrict__
token_mask
,
int32_t
model_offset
,
int32_t
topk_num
,
bool
has_expert_map
)
{
const
size_t
tid
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
y
;
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
int32_t
expert_id
=
topk_ids
[
i
];
if
(
expert_id
>=
num_experts
)
{
continue
;
}
if
(
has_expert_map
)
{
expert_id
=
expert_map
[
expert_id
];
// filter invalid experts
if
(
expert_id
==
-
1
)
continue
;
}
if
(
token_mask
==
nullptr
||
token_mask
[
i
/
topk_num
])
{
int32_t
rank_post_pad
=
atomicAdd
(
&
cumsum_buffer
[(
model_offset
*
(
num_experts
+
1
))
+
expert_id
],
1
);
sorted_token_ids
[
max_num_tokens_padded
*
model_offset
+
rank_post_pad
]
=
i
;
}
}
}
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
max_num_tokens_padded
,
int32_t
topk_num
,
bool
has_expert_map
)
{
_moe_align_block_size
(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
numel
,
cumsum
,
max_num_tokens_padded
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
0
,
0
,
topk_num
,
nullptr
,
has_expert_map
);
}
template
<
typename
scalar_t
>
__global__
void
count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
int32_t
*
__restrict__
expert_map
,
size_t
numel
,
int32_t
num_experts
,
int32_t
max_num_tokens_padded
,
int32_t
topk_num
,
bool
has_expert_map
)
{
_count_and_sort_expert_tokens
(
topk_ids
,
sorted_token_ids
,
cumsum_buffer
,
expert_map
,
numel
,
num_experts
,
max_num_tokens_padded
,
nullptr
,
0
,
topk_num
,
has_expert_map
);
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., topk, d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
scalar_t
x
=
0.0
;
#pragma unroll
for
(
int
k
=
0
;
k
<
TOPK
;
++
k
)
{
x
+=
VLLM_LDG
(
&
input
[
token_idx
*
TOPK
*
d
+
k
*
d
+
idx
]);
}
out
[
token_idx
*
d
+
idx
]
=
x
;
}
}
template
<
typename
scalar_t
,
int32_t
fill_threads
>
__global__
void
moe_align_block_size_small_batch_expert_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
expert_map
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
max_num_tokens_padded
,
int32_t
topk_num
,
bool
has_expert_map
)
{
_moe_align_block_size_small_batch_expert
<
scalar_t
,
fill_threads
>
(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
block_size
,
numel
,
max_num_tokens_padded
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
0
,
0
,
topk_num
,
nullptr
,
has_expert_map
);
}
template
<
typename
scalar_t
>
__global__
void
moe_lora_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
token_lora_mapping
,
int64_t
block_size
,
int32_t
*
__restrict__
expert_map
,
int
num_experts
,
int
max_loras
,
size_t
numel
,
int
max_num_tokens_padded
,
int
max_num_m_blocks
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
topk_num
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
adapter_enabled
,
int32_t
*
__restrict__
cumsum
,
int32_t
experts_per_warp
,
int32_t
padded_num_experts
,
int32_t
*
lora_ids
,
int32_t
*
__restrict__
token_mask
,
bool
has_expert_map
)
{
int
lora_idx
=
blockIdx
.
x
/
2
;
int
lora_id
=
lora_ids
[
lora_idx
];
if
(
lora_id
==
-
1
||
adapter_enabled
[
lora_id
]
==
0
)
{
return
;
}
// Populate the token_mask based on the token-LoRA mapping
int
num_tokens
=
numel
/
topk_num
;
if
(
threadIdx
.
x
==
0
)
{
total_tokens_post_pad
[
lora_id
]
=
0
;
for
(
int
i
=
0
;
i
<
num_tokens
;
i
++
)
{
token_mask
[(
lora_id
*
num_tokens
)
+
i
]
=
(
int
)
token_lora_mapping
[
i
]
==
lora_id
;
}
}
__syncthreads
();
_moe_align_block_size
(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
numel
,
cumsum
,
max_num_tokens_padded
,
max_num_m_blocks
,
lora_id
,
-
1
,
topk_num
,
&
token_mask
[(
lora_id
*
num_tokens
)],
has_expert_map
);
}
template
<
typename
scalar_t
>
__global__
void
lora_count_and_sort_expert_tokens_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
cumsum_buffer
,
int32_t
*
__restrict__
expert_map
,
size_t
numel
,
int32_t
num_experts
,
int32_t
max_num_tokens_padded
,
int32_t
topk_num
,
int32_t
*
token_mask
,
int32_t
*
lora_ids
,
bool
has_expert_map
)
{
int
lora_idx
=
blockIdx
.
x
;
int
lora_id
=
lora_ids
[
lora_idx
];
if
(
lora_id
==
-
1
)
{
return
;
}
int
num_tokens
=
numel
/
topk_num
;
_count_and_sort_expert_tokens
(
topk_ids
,
sorted_token_ids
,
cumsum_buffer
,
expert_map
,
numel
,
num_experts
,
max_num_tokens_padded
,
&
token_mask
[(
lora_id
*
num_tokens
)],
lora_id
,
topk_num
,
has_expert_map
);
}
template
<
typename
scalar_t
,
int32_t
fill_threads
>
__global__
void
moe_lora_align_block_size_small_batch_expert_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
token_lora_mapping
,
int64_t
block_size
,
int32_t
*
__restrict__
expert_map
,
int
num_experts
,
int
max_loras
,
size_t
numel
,
int
max_num_tokens_padded
,
int
max_num_m_blocks
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int
topk_num
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
adapter_enabled
,
int32_t
*
lora_ids
,
int32_t
*
token_mask
,
bool
has_expert_map
)
{
int
lora_idx
=
blockIdx
.
x
;
int
lora_id
=
lora_ids
[
lora_idx
];
if
(
lora_id
==
-
1
||
adapter_enabled
[
lora_id
]
==
0
)
{
return
;
}
int
num_tokens
=
numel
/
topk_num
;
if
(
threadIdx
.
x
==
0
)
{
total_tokens_post_pad
[
lora_id
]
=
0
;
for
(
int
i
=
0
;
i
<
num_tokens
;
i
++
)
{
token_mask
[(
lora_id
*
num_tokens
)
+
i
]
=
(
int
)
token_lora_mapping
[
i
]
==
lora_id
;
}
}
__syncthreads
();
_moe_align_block_size_small_batch_expert
<
scalar_t
,
fill_threads
>
(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
block_size
,
numel
,
max_num_tokens_padded
,
max_num_m_blocks
,
-
1
,
lora_id
,
topk_num
,
&
token_mask
[(
lora_id
*
num_tokens
)],
has_expert_map
);
}
}
}
// namespace moe
}
// namespace moe
...
@@ -365,7 +525,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -365,7 +525,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
sorted_token_ids
.
size
(
0
),
has_expert_map
);
topk_ids
.
numel
(),
sorted_token_ids
.
size
(
0
),
topk_ids
.
size
(
1
),
has_expert_map
);
}
else
{
}
else
{
torch
::
Tensor
cumsum_buffer
=
torch
::
Tensor
cumsum_buffer
=
torch
::
empty
({
num_experts
+
1
},
options_int
);
torch
::
empty
({
num_experts
+
1
},
options_int
);
...
@@ -386,21 +547,23 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -386,21 +547,23 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
size
(
0
),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
size
(
0
),
has_expert_map
);
topk_ids
.
size
(
1
),
has_expert_map
);
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
max_blocks
=
65535
;
const
int
max_blocks
=
65535
;
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
dim3
gridDims
(
1
,
actual_blocks
);
auto
sort_kernel
=
auto
sort_kernel
=
vllm
::
moe
::
count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
vllm
::
moe
::
count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
sort_kernel
<<<
actual_block
s
,
block_threads
,
0
,
stream
>>>
(
sort_kernel
<<<
gridDim
s
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
expert_map
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
expert_map
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
(),
num_experts
,
has_expert_map
);
topk_ids
.
numel
(),
num_experts
,
sorted_token_ids
.
size
(
0
),
topk_ids
.
size
(
1
),
has_expert_map
);
}
}
});
});
}
}
...
@@ -474,3 +637,123 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
...
@@ -474,3 +637,123 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
break
;
break
;
}
}
}
}
void
moe_lora_align_block_size
(
torch
::
Tensor
topk_ids
,
torch
::
Tensor
token_lora_mapping
,
int64_t
num_experts
,
int64_t
block_size
,
int64_t
max_loras
,
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
,
std
::
optional
<
torch
::
Tensor
>
maybe_expert_map
)
{
const
int
topk_num
=
topk_ids
.
size
(
1
);
TORCH_CHECK
(
block_size
>
0
,
"block_size should be greater than 0. "
);
int
device_max_shared_mem
;
auto
dev
=
topk_ids
.
get_device
();
cudaDeviceGetAttribute
(
&
device_max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int64_t
padded_num_experts
=
((
num_experts
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK
(
padded_num_experts
<
1024
,
"padded_num_experts must be less than 1024"
);
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
torch
::
Tensor
token_mask
=
torch
::
empty
({
max_loras
*
topk_ids
.
size
(
0
)},
options_int
);
bool
has_expert_map
=
maybe_expert_map
.
has_value
();
torch
::
Tensor
expert_map
;
if
(
has_expert_map
)
{
expert_map
=
maybe_expert_map
.
value
();
}
else
{
expert_map
=
torch
::
empty
({
0
},
options_int
);
}
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_lora_align_sum_kernel"
,
[
&
]
{
bool
small_batch_expert_mode
=
(
topk_ids
.
numel
()
<
1024
)
&&
(
num_experts
<=
64
);
if
(
small_batch_expert_mode
)
{
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
128
);
const
int32_t
shared_mem
=
(
num_thread
+
1
)
*
num_experts
*
sizeof
(
int32_t
)
+
(
num_experts
+
1
)
*
sizeof
(
int32_t
);
if
(
shared_mem
>
device_max_shared_mem
)
{
TORCH_CHECK
(
false
,
"Shared memory usage exceeds device limit."
);
}
// threadIdx.x >= fill_threads: counting experts and aligning
// threadIdx.x < fill_threads: filling sorted_token_ids
constexpr
int32_t
fill_threads
=
256
;
dim3
blockDim
(
num_thread
+
fill_threads
);
auto
kernel
=
vllm
::
moe
::
moe_lora_align_block_size_small_batch_expert_kernel
<
scalar_t
,
fill_threads
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
max_loras
,
blockDim
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
token_lora_mapping
.
data_ptr
<
int32_t
>
(),
block_size
,
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
max_loras
,
topk_ids
.
numel
(),
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
expert_ids
.
data_ptr
<
int32_t
>
(),
topk_num
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
adapter_enabled
.
data_ptr
<
int32_t
>
(),
lora_ids
.
data_ptr
<
int32_t
>
(),
token_mask
.
data_ptr
<
int32_t
>
(),
has_expert_map
);
}
else
{
int
num_thread
=
1024
;
dim3
blockDim
(
num_thread
);
size_t
num_warps
=
CEILDIV
(
padded_num_experts
,
WARP_SIZE
);
size_t
shared_mem_size
=
num_warps
*
WARP_SIZE
*
sizeof
(
int32_t
);
// cumsum buffer
torch
::
Tensor
cumsum
=
torch
::
zeros
({
max_loras
*
(
num_experts
+
1
)},
options_int
);
auto
align_kernel
=
vllm
::
moe
::
moe_lora_align_block_size_kernel
<
scalar_t
>
;
// launch two threadblocks for each lora
// blockIdx.x % 2 == 0: counting experts and aligning
// blockIdx.x % 2 == 1: filling sorted_token_ids
align_kernel
<<<
max_loras
*
2
,
blockDim
,
shared_mem_size
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
token_lora_mapping
.
data_ptr
<
int32_t
>
(),
block_size
,
expert_map
.
data_ptr
<
int32_t
>
(),
num_experts
,
max_loras
,
topk_ids
.
numel
(),
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
expert_ids
.
data_ptr
<
int32_t
>
(),
topk_num
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
adapter_enabled
.
data_ptr
<
int32_t
>
(),
cumsum
.
data_ptr
<
int32_t
>
(),
WARP_SIZE
,
padded_num_experts
,
lora_ids
.
data_ptr
<
int32_t
>
(),
token_mask
.
data_ptr
<
int32_t
>
(),
has_expert_map
);
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
num_thread
);
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
max_blocks
=
65535
;
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
dim3
gridDims
(
max_loras
,
actual_blocks
);
auto
sort_kernel
=
vllm
::
moe
::
lora_count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
sort_kernel
<<<
gridDims
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum
.
data_ptr
<
int32_t
>
(),
expert_map
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
(),
num_experts
,
max_num_tokens_padded
,
topk_num
,
token_mask
.
data_ptr
<
int32_t
>
(),
lora_ids
.
data_ptr
<
int32_t
>
(),
has_expert_map
);
}
});
}
\ No newline at end of file
csrc/moe/moe_lora_align_sum_kernels.cu
deleted
100644 → 0
View file @
db14f61f
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
return
row
*
total_col
+
col
;
}
}
// namespace
// TODO: Refactor common parts with moe_align_sum_kernels
template
<
typename
scalar_t
,
typename
token_cnts_t
>
__global__
void
moe_lora_align_sum_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
token_lora_mapping
,
int64_t
block_size
,
int
num_experts
,
int
max_loras
,
size_t
numel
,
int
max_num_tokens_padded
,
int
max_num_m_blocks
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int
topk_num
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
adapter_enabled
,
int32_t
*
lora_ids
)
{
const
size_t
tokens_per_thread
=
div_ceil
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
int
lora_idx
=
blockIdx
.
x
;
int
lora_id
=
lora_ids
[
lora_idx
];
if
(
lora_id
==
-
1
||
adapter_enabled
[
lora_id
]
==
0
)
{
return
;
}
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
cumsum
=
shared_mem
;
token_cnts_t
*
tokens_cnts
=
(
token_cnts_t
*
)(
shared_mem
+
num_experts
+
1
);
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
lora_id
*
max_num_tokens_padded
+
it
]
=
numel
;
}
// Initialize expert_ids with -1
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_m_blocks
;
it
+=
blockDim
.
x
)
{
expert_ids
[
lora_id
*
max_num_m_blocks
+
it
]
=
-
1
;
}
// Initialize total_tokens_post_pad with 0
if
(
threadIdx
.
x
==
0
)
{
total_tokens_post_pad
[
lora_id
]
=
0
;
}
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int
mask
=
token_lora_mapping
[
i
/
topk_num
]
==
lora_id
;
int
idx
=
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
]);
tokens_cnts
[
idx
]
+=
mask
;
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
if
(
threadIdx
.
x
<
num_experts
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
div_ceil
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
total_tokens_post_pad
[
lora_id
]
=
static_cast
<
int32_t
>
(
cumsum
[
num_experts
]);
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
index
(
max_num_m_blocks
,
lora_id
,
i
/
block_size
)]
=
threadIdx
.
x
;
}
}
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
int
mask
=
(
int
)
token_lora_mapping
[
i
/
topk_num
]
==
lora_id
;
atomicAdd
(
&
sorted_token_ids
[
index
(
max_num_tokens_padded
,
lora_id
,
rank_post_pad
)],
(
i
-
numel
)
*
mask
);
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+=
mask
;
}
}
void
moe_lora_align_block_size
(
torch
::
Tensor
topk_ids
,
torch
::
Tensor
token_lora_mapping
,
int64_t
num_experts
,
int64_t
block_size
,
int64_t
max_loras
,
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
)
{
const
int
topk_num
=
topk_ids
.
size
(
1
);
TORCH_CHECK
(
block_size
>
0
,
"block_size should be greater than 0. "
);
int
device_max_shared_mem
;
auto
dev
=
topk_ids
.
get_device
();
cudaDeviceGetAttribute
(
&
device_max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
128
);
// WARP_SIZE,
TORCH_CHECK
(
num_thread
<=
1024
,
"num_thread must be less than 1024, "
"and fallback is not implemented yet."
);
const
int32_t
shared_mem
=
(
num_thread
+
1
)
*
num_experts
*
sizeof
(
int32_t
)
+
(
num_experts
+
1
)
*
sizeof
(
int32_t
);
if
(
shared_mem
>
device_max_shared_mem
)
{
TORCH_CHECK
(
false
,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet."
);
}
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_lora_align_sum_kernel"
,
[
&
]
{
dim3
blockDim
(
num_thread
);
auto
kernel
=
moe_lora_align_sum_kernel
<
scalar_t
,
int32_t
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
max_loras
,
blockDim
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
token_lora_mapping
.
data_ptr
<
int32_t
>
(),
block_size
,
num_experts
,
max_loras
,
topk_ids
.
numel
(),
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
expert_ids
.
data_ptr
<
int32_t
>
(),
topk_num
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
adapter_enabled
.
data_ptr
<
int32_t
>
(),
lora_ids
.
data_ptr
<
int32_t
>
());
});
}
\ No newline at end of file
csrc/moe/moe_ops.h
View file @
ea657f20
...
@@ -27,7 +27,7 @@ void moe_lora_align_block_size(
...
@@ -27,7 +27,7 @@ void moe_lora_align_block_size(
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
);
torch
::
Tensor
lora_ids
,
std
::
optional
<
torch
::
Tensor
>
maybe_expert_map
);
#ifndef USE_ROCM
#ifndef USE_ROCM
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
...
...
csrc/moe/torch_bindings.cpp
View file @
ea657f20
...
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor !experts_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "
);
" Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () "
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
#ifndef USE_ROCM
#ifndef USE_ROCM
...
...
tests/lora/test_moe_lora_align_sum.py
View file @
ea657f20
...
@@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
...
@@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
@
pytest
.
mark
.
parametrize
(
"topk_num"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"topk_num"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
64
,
128
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"max_loras"
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_loras"
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_moe_lora_align_block_size
(
def
test_moe_lora_align_block_size
(
...
...
vllm/_custom_ops.py
View file @
ea657f20
...
@@ -1961,6 +1961,7 @@ def moe_lora_align_block_size(
...
@@ -1961,6 +1961,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
topk_ids
,
topk_ids
,
...
@@ -1975,6 +1976,7 @@ def moe_lora_align_block_size(
...
@@ -1975,6 +1976,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
,
num_tokens_post_pad
,
adapter_enabled
,
adapter_enabled
,
lora_ids
,
lora_ids
,
expert_map
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment