Commit 6b0aeb58 (unverified) authored by Cheng Wan, committed by GitHub

[moe] optim: reduce memory consumption in fused_moe (#3692)

parent bb3e5268
@@ -933,20 +933,21 @@ def fused_experts_impl(
     config = get_config_func(M)
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
+    cache = torch.empty(
+        M * topk_ids.shape[1] * max(N, w2.shape[1]),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
+    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
+        (M, topk_ids.shape[1], N),
+    )
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N // 2),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = torch.empty(
+    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
         (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
     )
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
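The reuse is safe because the two intermediates are never live at the same time: `intermediate_cache1` is consumed by the activation step that fills `intermediate_cache2`, and `intermediate_cache3` is written only after that, so both can alias the same storage. Below is a minimal, self-contained sketch of the aliasing trick; the shape values and the names `M`, `topk`, `N`, `K` are illustrative placeholders (with `K` standing in for `w2.shape[1]`), not values from the commit.

```python
import torch

# Illustrative shapes only: M tokens, topk experts per token,
# N = w1 output width, K = w2.shape[1] (final hidden size).
M, topk, N, K = 4, 2, 512, 256

# One flat buffer sized for the larger of the two intermediates
# replaces the two separate allocations used before this commit.
cache = torch.empty(M * topk * max(N, K), dtype=torch.bfloat16)

# Both views alias the front of the same storage; because their
# lifetimes do not overlap, neither clobbers data the other needs.
intermediate_cache1 = cache[: M * topk * N].view(M, topk, N)
intermediate_cache3 = cache[: M * topk * K].view(M, topk, K)

# The two tensors share one allocation.
assert intermediate_cache1.data_ptr() == intermediate_cache3.data_ptr()
print(intermediate_cache1.shape, intermediate_cache3.shape)
```

Compared with allocating the two caches separately, this saves roughly `M * topk_ids.shape[1] * min(N, w2.shape[1])` elements per call, since only the larger of the two buffers is ever materialized.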