Commit 5a3adf58 (unverified), authored by gnovack, committed by GitHub
Browse files

fused_moe_lora PDL improvements (#30716)


Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 6fe58876
...@@ -156,16 +156,22 @@ def _fused_moe_lora_kernel( ...@@ -156,16 +156,22 @@ def _fused_moe_lora_kernel(
+ offs_bn[None, :] * stride_bn + offs_bn[None, :] * stride_bn
) )
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
# accumulator # accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# GDC wait waits for ALL programs in the prior kernel to complete # GDC wait waits for ALL programs in the prior kernel to complete
# before continuing. # before continuing.
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
a = tl.load( a = tl.load(
a_ptrs, a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
...@@ -179,9 +185,6 @@ def _fused_moe_lora_kernel( ...@@ -179,9 +185,6 @@ def _fused_moe_lora_kernel(
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
accumulator = accumulator.to(c_ptr.dtype.element_ty) accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
...@@ -290,6 +293,7 @@ def _fused_moe_lora_shrink( ...@@ -290,6 +293,7 @@ def _fused_moe_lora_shrink(
def _fused_moe_lora_expand( def _fused_moe_lora_expand(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size)
lora_b_stacked: list[ lora_b_stacked: list[
torch.Tensor torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...] ], # [(max_loras, num_experts, max_lora_rank, K,),...]
...@@ -331,11 +335,6 @@ def _fused_moe_lora_expand( ...@@ -331,11 +335,6 @@ def _fused_moe_lora_expand(
-1, a_intermediate_cache1.shape[3] -1, a_intermediate_cache1.shape[3]
) )
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
use_gdc = supports_pdl(a_intermediate_cache1.device) use_gdc = supports_pdl(a_intermediate_cache1.device)
expand_config = { expand_config = {
"BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_M": block_size_m,
...@@ -460,6 +459,12 @@ def _fused_moe_lora( ...@@ -460,6 +459,12 @@ def _fused_moe_lora(
device=device, device=device,
) )
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
_fused_moe_lora_shrink( _fused_moe_lora_shrink(
a_intermediate_cache1, a_intermediate_cache1,
qcurr_hidden_states, qcurr_hidden_states,
...@@ -506,6 +511,7 @@ def _fused_moe_lora( ...@@ -506,6 +511,7 @@ def _fused_moe_lora(
_fused_moe_lora_expand( _fused_moe_lora_expand(
output, output,
a_intermediate_cache1, a_intermediate_cache1,
b_intermediate_cache1,
lora_b_stacked, lora_b_stacked,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment