Unverified commit d5503ca7, authored by Jee Jee Li, committed by GitHub
Browse files

[LoRA] LoRA PDL improvement (#31660)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent a2ad15c0
...@@ -163,15 +163,17 @@ def _fused_moe_lora_kernel( ...@@ -163,15 +163,17 @@ def _fused_moe_lora_kernel(
# accumulator # accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
for k in range(0, grid_k): for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight # pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load( a = tl.load(
a_ptrs, a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment