Unverified Commit 8e4ca4d1 authored by gnovack's avatar gnovack Committed by GitHub
Browse files

Bugfix - pass 'max_num_tokens_padded' into 'moe_lora_align_block_size' (#27311)


Signed-off-by: default avatargnovack <gnovack@amazon.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 1a0f4def
...@@ -124,18 +124,14 @@ __global__ void moe_lora_align_sum_kernel( ...@@ -124,18 +124,14 @@ __global__ void moe_lora_align_sum_kernel(
void moe_lora_align_block_size(torch::Tensor topk_ids, void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
torch::Tensor expert_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const int topk_num = topk_ids.size(1); const int topk_num = topk_ids.size(1);
int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
max_num_tokens_padded = round_to_next_multiple_of(
max_num_tokens_padded, static_cast<int>(block_size));
int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
int device_max_shared_mem; int device_max_shared_mem;
auto dev = topk_ids.get_device(); auto dev = topk_ids.get_device();
......
...@@ -23,7 +23,8 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, ...@@ -23,7 +23,8 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
void moe_lora_align_block_size(torch::Tensor topk_ids, void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
torch::Tensor expert_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
......
...@@ -40,6 +40,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -40,6 +40,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor token_lora_mapping," " Tensor token_lora_mapping,"
" int num_experts," " int num_experts,"
" int block_size, int max_loras, " " int block_size, int max_loras, "
" int max_num_tokens_padded, "
" int max_num_m_blocks, "
" Tensor !sorted_token_ids," " Tensor !sorted_token_ids,"
" Tensor !experts_ids," " Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () "); " Tensor !num_tokens_post_pad) -> () ");
......
...@@ -142,6 +142,8 @@ def use_fused_moe_lora_kernel( ...@@ -142,6 +142,8 @@ def use_fused_moe_lora_kernel(
num_experts, num_experts,
block_size, block_size,
max_loras, max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
......
...@@ -36,7 +36,7 @@ def test_gptoss20b_lora(gptoss20b_lora_files): ...@@ -36,7 +36,7 @@ def test_gptoss20b_lora(gptoss20b_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=4,
trust_remote_code=True, trust_remote_code=True,
) )
......
...@@ -68,6 +68,8 @@ def test_moe_lora_align_block_size( ...@@ -68,6 +68,8 @@ def test_moe_lora_align_block_size(
num_experts, num_experts,
block_size, block_size,
max_loras, max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_pad, num_tokens_post_pad,
......
...@@ -1801,6 +1801,8 @@ def moe_lora_align_block_size( ...@@ -1801,6 +1801,8 @@ def moe_lora_align_block_size(
num_experts: int, num_experts: int,
block_size: int, block_size: int,
max_loras: int, max_loras: int,
max_num_tokens_padded: int,
max_num_m_blocks: int,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor, experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor, num_tokens_post_pad: torch.Tensor,
...@@ -1811,6 +1813,8 @@ def moe_lora_align_block_size( ...@@ -1811,6 +1813,8 @@ def moe_lora_align_block_size(
num_experts, num_experts,
block_size, block_size,
max_loras, max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_token_ids, sorted_token_ids,
experts_ids, experts_ids,
num_tokens_post_pad, num_tokens_post_pad,
......
...@@ -341,6 +341,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -341,6 +341,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
num_experts, num_experts,
block_size, block_size,
max_loras, max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids, sorted_ids,
expert_ids, expert_ids,
num_tokens_post_pad, num_tokens_post_pad,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment