Unverified Commit f210f0b7 authored by cwazai's avatar cwazai Committed by GitHub
Browse files

[lora/moe] Avoid extra intermediate buffer & Python slicing in expand phase...


[lora/moe] Avoid extra intermediate buffer & Python slicing in expand phase when split_k == 1 (#32774)
Signed-off-by: default avatar陈建华 <1647430658@qq.com>
parent 392c5af4
...@@ -84,6 +84,7 @@ def _fused_moe_lora_kernel( ...@@ -84,6 +84,7 @@ def _fused_moe_lora_kernel(
num_slice_c: tl.constexpr, num_slice_c: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
ADD_INPUTS: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
...@@ -211,7 +212,11 @@ def _fused_moe_lora_kernel( ...@@ -211,7 +212,11 @@ def _fused_moe_lora_kernel(
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1: if SPLIT_K == 1:
tl.store(c_ptrs, accumulator, mask=c_mask) if ADD_INPUTS:
prev = tl.load(c_ptrs, mask=c_mask, other=0.0)
tl.store(c_ptrs, prev + accumulator, mask=c_mask)
else:
tl.store(c_ptrs, accumulator, mask=c_mask)
else: else:
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed") tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
...@@ -305,6 +310,7 @@ def _fused_moe_lora_shrink( ...@@ -305,6 +310,7 @@ def _fused_moe_lora_shrink(
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num, top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False, MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False,
USE_B_L2_CACHE=True, # new USE_B_L2_CACHE=True, # new
IS_PRIMARY=True, IS_PRIMARY=True,
**shrink_config, **shrink_config,
...@@ -315,7 +321,6 @@ def _fused_moe_lora_shrink( ...@@ -315,7 +321,6 @@ def _fused_moe_lora_shrink(
def _fused_moe_lora_expand( def _fused_moe_lora_expand(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size)
lora_b_stacked: list[ lora_b_stacked: list[
torch.Tensor torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...] ], # [(max_loras, num_experts, max_lora_rank, K,),...]
...@@ -376,10 +381,15 @@ def _fused_moe_lora_expand( ...@@ -376,10 +381,15 @@ def _fused_moe_lora_expand(
## max_loras + 1 to handle the no-lora case (lora_id == -1) ## max_loras + 1 to handle the no-lora case (lora_id == -1)
lora_b_stacked[0].shape[0] + 1, lora_b_stacked[0].shape[0] + 1,
) )
# Fast path: directly accumulate into the corresponding slice interval of output.
out_view = output[:, :, offset : offset + num_slices * N]
slice_c_size = N * out_view.stride(2)
_fused_moe_lora_kernel[grid]( _fused_moe_lora_kernel[grid](
a_intermediate_cache1, a_intermediate_cache1,
b_ptr, b_ptr,
b_intermediate_cache1, out_view,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
...@@ -398,22 +408,21 @@ def _fused_moe_lora_expand( ...@@ -398,22 +408,21 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked.stride(1), w1_lora_b_stacked.stride(1),
w1_lora_b_stacked.stride(3), w1_lora_b_stacked.stride(3),
w1_lora_b_stacked.stride(2), w1_lora_b_stacked.stride(2),
b_intermediate_cache1.stride(2), out_view.stride(1),
b_intermediate_cache1.stride(3), out_view.stride(2),
sorted_token_ids.stride(0), sorted_token_ids.stride(0),
expert_ids.stride(0), expert_ids.stride(0),
slice_a_size=a_intermediate_cache1.numel() // num_slices, slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=b_intermediate_cache1.numel() // num_slices, slice_c_size=slice_c_size,
num_slice_a=num_slices, num_slice_a=num_slices,
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1, top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True,
USE_B_L2_CACHE=True, # new USE_B_L2_CACHE=True, # new
IS_PRIMARY=False, IS_PRIMARY=False,
**expand_config, **expand_config,
) )
for i in range(num_slices):
output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
@torch.inference_mode() @torch.inference_mode()
...@@ -484,11 +493,6 @@ def _fused_moe_lora( ...@@ -484,11 +493,6 @@ def _fused_moe_lora(
device=device, device=device,
) )
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
use_gdc = supports_pdl(device) and not fully_sharded use_gdc = supports_pdl(device) and not fully_sharded
_fused_moe_lora_shrink( _fused_moe_lora_shrink(
a_intermediate_cache1, a_intermediate_cache1,
...@@ -537,7 +541,6 @@ def _fused_moe_lora( ...@@ -537,7 +541,6 @@ def _fused_moe_lora(
_fused_moe_lora_expand( _fused_moe_lora_expand(
output, output,
a_intermediate_cache1, a_intermediate_cache1,
b_intermediate_cache1,
lora_b_stacked, lora_b_stacked,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment