Unverified Commit 6b0aeb58 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[moe] optim: reduce memory consumption in fused_moe (#3692)

parent bb3e5268
...@@ -933,20 +933,21 @@ def fused_experts_impl( ...@@ -933,20 +933,21 @@ def fused_experts_impl(
config = get_config_func(M) config = get_config_func(M)
intermediate_cache1 = torch.empty( cache = torch.empty(
(M, topk_ids.shape[1], N), M * topk_ids.shape[1] * max(N, w2.shape[1]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
(M, topk_ids.shape[1], N),
)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2), (M * topk_ids.shape[1], N // 2),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache3 = torch.empty( intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
(M, topk_ids.shape[1], w2.shape[1]), (M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
) )
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment