Commit 0ead9141 authored by xuwx1's avatar xuwx1
Browse files

updata moe_utils.py

parent f43ec2dd
Pipeline #2549 passed with stage
......@@ -19,7 +19,7 @@ try:
except ImportError:
HAVE_TE = False
@torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True})
def switch_load_balancing_loss_func(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
......@@ -217,7 +217,7 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):
"""
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
@torch.compile(mode='max-autotune-no-cudagraphs')
def permute(
tokens,
routing_map,
......@@ -278,7 +278,7 @@ def permute(
return permuted_input, sorted_indices
@torch.compile(mode='max-autotune-no-cudagraphs')
def unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment