Unverified commit a532c838, authored by gnovack, committed by GitHub
Browse files

use 'max_active_experts' for moe lora input size (#33197)


Signed-off-by: gnovack <gnovack@amazon.com>
parent 1e5ad9b7
...@@ -47,6 +47,8 @@ def test_moe_lora_align_block_size( ...@@ -47,6 +47,8 @@ def test_moe_lora_align_block_size(
# compute paddings # compute paddings
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = topk_ids.numel() * block_size
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
# init output tensors # init output tensors
......
...@@ -351,6 +351,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -351,6 +351,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids: if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = topk_ids.numel() * block_size
sorted_ids = torch.empty( sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,), (max_loras * max_num_tokens_padded,),
dtype=torch.int32, dtype=torch.int32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment