Commit 0ead9141 authored by xuwx1's avatar xuwx1
Browse files

updata moe_utils.py

parent f43ec2dd
Pipeline #2549 passed with stage
...@@ -19,7 +19,7 @@ try: ...@@ -19,7 +19,7 @@ try:
except ImportError: except ImportError:
HAVE_TE = False HAVE_TE = False
@torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True})
def switch_load_balancing_loss_func( def switch_load_balancing_loss_func(
probs: torch.Tensor, probs: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
...@@ -217,7 +217,7 @@ class MoEAuxLossAutoScaler(torch.autograd.Function): ...@@ -217,7 +217,7 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):
""" """
MoEAuxLossAutoScaler.main_loss_backward_scale = scale MoEAuxLossAutoScaler.main_loss_backward_scale = scale
@torch.compile(mode='max-autotune-no-cudagraphs')
def permute( def permute(
tokens, tokens,
routing_map, routing_map,
...@@ -278,7 +278,7 @@ def permute( ...@@ -278,7 +278,7 @@ def permute(
return permuted_input, sorted_indices return permuted_input, sorted_indices
@torch.compile(mode='max-autotune-no-cudagraphs')
def unpermute( def unpermute(
permuted_tokens: torch.Tensor, permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor, sorted_indices: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment