Unverified commit 3f70a2b1, authored by HELSON, committed by GitHub
Browse files

Skip the noisy function during evaluation in the MoE routers (Top1Router, Top2Router) (#419)

parent adebb3e0
...@@ -52,7 +52,7 @@ class Top1Router(nn.Module): ...@@ -52,7 +52,7 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None: if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs) inputs_noisy = self.noisy_func(inputs)
else: else:
inputs_noisy = inputs inputs_noisy = inputs
...@@ -126,7 +126,7 @@ class Top2Router(nn.Module): ...@@ -126,7 +126,7 @@ class Top2Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, h] # inputs: [s, h]
if self.noisy_func is not None: if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs) inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e] logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment