Unverified Commit e33192b2 authored by cwazai's avatar cwazai Committed by GitHub
Browse files

[lora/moe] Improve fused MoE‑LoRA kernel indexing and memory access (#32770)


Signed-off-by: default avatar陈建华 <1647430658@qq.com>
Signed-off-by: default avatarYanwen Lin <lyw1124278064@gmail.com>
Signed-off-by: default avatarkimheesu <wlskaka4@gmail.com>
Signed-off-by: default avatarDivakar Verma <divakar.verma@amd.com>
Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarganyi <ygan@amd.com>
Signed-off-by: default avatarwhx-sjtu <2952154980@qq.com>
Signed-off-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
Signed-off-by: default avatarYanan Cao <gmagogsfm@gmail.com>
Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarMatthew Wong <Matthew.Wong2@amd.com>
Signed-off-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: default avatarIfta Khairul Alam Adil <ikaadil007@gmail.com>
Signed-off-by: Ifta khairul Alam Adil <25082512+ikaadil@users.noreply.gith...
parent 61274bde
......@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
num_experts,
lora_ids,
adapter_enabled,
max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
......@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel(
num_slice_c: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
......@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0:
# Early exit for the no moe lora case.
return
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
max_loras = tl.num_programs(axis=2) - 1
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# This guard ensures we don't access sorted_token_ids / expert_ids /
# num_tokens_post_padded beyond their allocated bounds if an invalid
# lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if lora_id >= max_loras:
return
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
......@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind,
......@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
......@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink(
num_experts,
lora_ids,
adapter_enabled,
lora_a_stacked[0].shape[0],
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
......@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink(
num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=True,
**shrink_config,
)
......@@ -377,6 +391,7 @@ def _fused_moe_lora_expand(
num_experts,
lora_ids,
adapter_enabled,
lora_b_stacked[0].shape[0],
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
......@@ -393,6 +408,7 @@ def _fused_moe_lora_expand(
num_slice_c=num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=False,
**expand_config,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment