"tests/vscode:/vscode.git/clone" did not exist on "b9b469ea508717f72cc6c1a5cb8d2e522c87bb23"
Unverified Commit 3f70a2b1 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

removed noisy function during evaluation of MoE router (#419)

parent adebb3e0
......@@ -52,7 +52,7 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
......@@ -126,7 +126,7 @@ class Top2Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, h]
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment