Unverified Commit 7320ca39 authored by Runkai Tao's avatar Runkai Tao Committed by GitHub
Browse files

Add unpermute-aware fused MoE LoRA path (#32655)


Signed-off-by: default avatarRunkai Tao <rt572@physics.rutgers.edu>
parent cf0a99f8
......@@ -842,6 +842,7 @@ class BenchmarkTensors:
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
......@@ -915,6 +916,7 @@ class BenchmarkTensors:
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
......
......@@ -190,6 +190,7 @@ def use_fused_moe_lora_kernel(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_ids,
......@@ -333,6 +334,189 @@ def test_fused_moe_lora_kernel(
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
def use_fused_moe_lora_kernel_naive(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
block_size,
fully_sharded=False,
offset=0,
):
"""
Test helper for naive_block_assignment path.
Skips moe_lora_align_block_size and uses flattened topk_ids as expert_ids.
"""
config = {
"BLOCK_SIZE_M": block_size,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"NUM_WARPS": 4,
"NUM_STAGES": 3,
"SPLIT_K": 1,
}
mul_routed_weight = False
# In naive mode:
# - expert_ids = topk_ids.view(-1), shape: (num_tokens * top_k,)
# - sorted_token_ids = None
# - num_tokens_post_padded = None
expert_ids = topk_ids.reshape(-1)
sorted_token_ids = None
num_tokens_post_padded = None
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
fused_moe_lora(
output,
hidden_states,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_ids,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config["NUM_WARPS"],
config["NUM_STAGES"],
config["SPLIT_K"],
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config["NUM_WARPS"],
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight=mul_routed_weight,
fully_sharded=fully_sharded,
offset=offset,
)
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8])
@pytest.mark.parametrize("top_k_num", [1, 2])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [4, 8])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_fused_moe_lora_kernel_naive_block_assignment(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
dtype,
device,
seed,
):
"""
Test the naive_block_assignment path of the fused_moe_lora kernel.
This path is triggered when batch_size * top_k is much smaller than
num_experts * max_loras, and skips the moe_lora_align_block_size kernel.
"""
torch.set_default_device(device)
set_random_seed(seed)
# Verify this configuration would trigger naive_block_assignment
# (num_tokens * top_k * SPARSITY_FACTOR <= num_experts * max_loras)
SPARSITY_FACTOR = 8
assert num_tokens * top_k_num * SPARSITY_FACTOR <= num_experts * max_loras, (
f"Test configuration doesn't meet naive_block_assignment condition: "
f"{num_tokens} * {top_k_num} * {SPARSITY_FACTOR} > {num_experts} * {max_loras}"
)
# the number of randomly generated sentences.
num_sequences = min(num_tokens, 4)
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)
# init lora weights
lora_a_stacked = [
torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
K,
),
dtype=dtype,
)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
max_lora_rank,
),
dtype=dtype,
)
]
hidden_states = torch.rand(
(
num_tokens,
K,
),
dtype=dtype,
)
# fused_moe_lora_kernel output (naive path)
output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
use_fused_moe_lora_kernel_naive(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
block_size,
)
# pytorch reference output
output_ref = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
lora_a_stacked,
lora_b_stacked,
top_k_num,
)
torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6])
......
......@@ -190,8 +190,18 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config_dtype=config_dtype,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
# activates only a small fraction of total experts * loras.
SPARSITY_FACTOR = 8
naive_block_assignment = (
expert_map is None
and num_tokens * top_k * SPARSITY_FACTOR
<= self.base_layer.local_num_experts * self.max_loras
)
# get the block size of m from customized config or default config
(
token_lora_mapping,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
......@@ -203,6 +213,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.max_loras,
self.adapter_enabled,
expert_map,
naive_block_assignment,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
......@@ -210,9 +221,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
moe_state_dict["token_lora_mapping"] = token_lora_mapping
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
#
self.punica_wrapper.add_lora_fused_moe(
......@@ -230,6 +245,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
expand_config, ## pass the expand config
self.adapter_enabled,
fully_sharded=self.fully_sharded,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
......@@ -270,9 +286,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
token_lora_mapping = moe_state_dict.get("token_lora_mapping")
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
......@@ -295,6 +315,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
True,
fully_sharded=self.fully_sharded,
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
......
......@@ -12,6 +12,64 @@ from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
@triton.jit
def _get_lora_id(
lora_ids,
token_lora_mapping_ptr,
lora_idx,
pid_m,
top_k_num,
naive_block_assignment: tl.constexpr,
):
"""Returns lora_id"""
if naive_block_assignment:
token_idx = pid_m // top_k_num
return tl.load(token_lora_mapping_ptr + token_idx)
else:
return tl.load(lora_ids + lora_idx)
@triton.jit
def _get_expert_id(
expert_ids_ptr,
lora_id,
pid_m,
stride_el,
max_loras,
naive_block_assignment: tl.constexpr,
):
"""Returns expert_id"""
if naive_block_assignment:
return tl.load(expert_ids_ptr + pid_m)
else:
ind = lora_id * stride_el + pid_m
return tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
@triton.jit
def _get_token_offs(
sorted_token_ids_ptr,
lora_id,
pid_m,
offs,
stride_tl,
max_loras,
num_valid_tokens,
naive_block_assignment: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Returns token offsets"""
if naive_block_assignment:
return tl.where(offs == 0, pid_m, num_valid_tokens)
else:
offs_token_id = pid_m * BLOCK_SIZE_M + offs
token_ind = stride_tl * lora_id + offs_token_id
return tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
......@@ -36,6 +94,25 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
return _LORA_PTR_DICT.get(key)
def _adjust_kernel_inputs(
max_loras: int,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
"""
helper function to adjust kernel inputs when sorted_token_ids is None
"""
if sorted_token_ids is None:
stride_tl = 0
stride_el = 0
grid_lora_dim = 1
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = max_loras + 1
return grid_lora_dim, stride_tl, stride_el
@triton.jit(
do_not_specialize=[
"num_valid_tokens",
......@@ -54,12 +131,14 @@ def _fused_moe_lora_kernel(
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
token_lora_mapping_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras
......@@ -82,7 +161,11 @@ def _fused_moe_lora_kernel(
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
top_k: tl.constexpr,
# top_k_num or 1 depending on input token
# is expanded by top_k or not
token_mapping_factor: tl.constexpr,
# whether use naive block assignment
naive_block_assignment: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
ADD_INPUTS: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
......@@ -97,26 +180,10 @@ def _fused_moe_lora_kernel(
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
moe_enabled = tl.load(adapter_enabled + lora_id)
if moe_enabled == 0:
# Early exit for the no moe lora case.
return
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# This guard ensures we don't access sorted_token_ids / expert_ids /
# num_tokens_post_padded beyond their allocated bounds if an invalid
# lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if lora_id >= max_loras:
return
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
lora_idx = tl.program_id(axis=2)
pid_sk = pid % SPLIT_K
pid_m_n = pid // SPLIT_K
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
......@@ -129,14 +196,55 @@ def _fused_moe_lora_kernel(
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
# Get lora_id
lora_id = _get_lora_id(
lora_ids,
token_lora_mapping_ptr,
lora_idx,
pid_m,
top_k_num,
naive_block_assignment,
)
if lora_id == -1:
return
moe_enabled = tl.load(adapter_enabled + lora_id)
if moe_enabled == 0:
return
if lora_id >= max_loras:
return
# Non-naive only: check num_tokens_post_padded
if not naive_block_assignment:
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# get the expert_id to process curr shard
ind = lora_id * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
# Get expert_id
expert_id = _get_expert_id(
expert_ids_ptr,
lora_id,
pid_m,
stride_el,
max_loras,
naive_block_assignment,
)
if expert_id == -1:
return
# Get token offsets
offs_token = _get_token_offs(
sorted_token_ids_ptr,
lora_id,
pid_m,
offs,
stride_tl,
max_loras,
num_valid_tokens,
naive_block_assignment,
BLOCK_SIZE_M,
)
# get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
......@@ -145,19 +253,12 @@ def _fused_moe_lora_kernel(
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind,
mask=token_ind < max_loras * stride_tl,
other=num_valid_tokens,
)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
offs_token[:, None] // token_mapping_factor * stride_am
+ offs_k[None, :] * stride_ak
)
b_ptrs = (
......@@ -230,9 +331,10 @@ def _fused_moe_lora_shrink(
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
......@@ -270,13 +372,15 @@ def _fused_moe_lora_shrink(
b_ptr = _get_ptr(lora_a_stacked, device)
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids
)
grid = lambda META: (
split_k
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
## max_loras + 1 to handle the no-lora case (lora_id == -1)
lora_a_stacked[0].shape[0] + 1,
grid_lora_dim,
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
......@@ -286,11 +390,13 @@ def _fused_moe_lora_shrink(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
N,
K,
EM,
num_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
lora_a_stacked[0].shape[0],
......@@ -302,13 +408,14 @@ def _fused_moe_lora_shrink(
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
stride_tl,
stride_el,
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
num_slice_a=1,
num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num,
token_mapping_factor=1 if mul_routed_weight else top_k_num,
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False,
USE_B_L2_CACHE=True, # new
......@@ -325,9 +432,10 @@ def _fused_moe_lora_expand(
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
......@@ -375,11 +483,14 @@ def _fused_moe_lora_expand(
"launch_pdl": use_gdc, # triton kernel metadata
}
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids
)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
## max_loras + 1 to handle the no-lora case (lora_id == -1)
lora_b_stacked[0].shape[0] + 1,
grid_lora_dim,
)
# Fast path: directly accumulate into the corresponding slice interval of output.
......@@ -394,11 +505,13 @@ def _fused_moe_lora_expand(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
N,
K,
EM,
num_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
lora_b_stacked[0].shape[0],
......@@ -410,13 +523,14 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked.stride(2),
out_view.stride(1),
out_view.stride(2),
sorted_token_ids.stride(0),
expert_ids.stride(0),
stride_tl,
stride_el,
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=slice_c_size,
num_slice_a=num_slices,
num_slice_c=num_slices,
top_k=1,
token_mapping_factor=1,
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True,
USE_B_L2_CACHE=True, # new
......@@ -436,9 +550,10 @@ def _fused_moe_lora(
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
......@@ -462,6 +577,12 @@ def _fused_moe_lora(
offset: int = 0,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert topk_weights.dim() == qcurr_hidden_states.dim() == 2
if sorted_token_ids is None:
assert expert_ids.dim() == 1
else:
assert sorted_token_ids is not None
assert num_tokens_post_padded is not None
assert (
sorted_token_ids.dim()
== expert_ids.dim()
......@@ -482,10 +603,15 @@ def _fused_moe_lora(
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
EM = sorted_token_ids.shape[1]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
assert shrink_block_size_m == expand_block_size_m
EM = (
sorted_token_ids.shape[1]
if sorted_token_ids is not None
else num_tokens * shrink_block_size_m
)
a_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, max_lora_rank),
......@@ -502,6 +628,7 @@ def _fused_moe_lora(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
top_k_num,
lora_ids,
adapter_enabled,
......@@ -546,6 +673,7 @@ def _fused_moe_lora(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
top_k_num,
lora_ids,
adapter_enabled,
......@@ -579,9 +707,10 @@ def _fused_moe_lora_fake(
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
......@@ -610,9 +739,10 @@ def _fused_moe_lora_shrink_fake(
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
......@@ -642,9 +772,10 @@ def _fused_moe_lora_expand_fake(
a_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
......
......@@ -458,7 +458,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
......@@ -473,9 +473,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
max_lora_rank: int,
top_k_num: int,
shrink_config,
......@@ -484,6 +484,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
token_lora_mapping: torch.Tensor | None = None,
):
"""
Performs a fused forward computation for LoRA of
......
......@@ -310,11 +310,20 @@ class PunicaWrapperGPU(PunicaWrapperBase):
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
naive_block_assignment: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
)
if naive_block_assignment:
expert_ids = topk_ids.reshape(-1)
sorted_ids = None
num_tokens_post_pad = None
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
......@@ -334,10 +343,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
(max_loras), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
......@@ -355,7 +360,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
return None, sorted_ids, expert_ids, num_tokens_post_pad
def add_lora_fused_moe(
self,
......@@ -364,9 +369,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
max_lora_rank: int,
top_k_num: int,
shrink_config,
......@@ -375,11 +380,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
token_lora_mapping: torch.Tensor | None = None,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
(
token_lora_mapping_meta,
_,
_,
_,
lora_ids,
_,
) = self.token_mapping_meta.meta_args(x.size(0))
if token_lora_mapping is None:
token_lora_mapping = token_lora_mapping_meta
fused_moe_lora(
y,
x,
......@@ -389,6 +404,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment