Unverified Commit c651eb23 authored by Wangyi Jiang, committed by GitHub

Simplify the implementation of jitter noise in MoE models (#27643)

parent b54993aa
@@ -188,17 +188,8 @@ class GPTSanJapaneseTop1Router(nn.Module):
         hidden_states = hidden_states.to(self.dtype)
         if self.jitter_noise > 0:
-            # Get the lower and upper bound of the uniform distribution
-            # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
-            distrib_lower_bound = 1.0 - self.jitter_noise
-            distrib_upper_bound = 1.0 + self.jitter_noise
-            uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
-            uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
-            uniform_distrib = uniform_distrib + distrib_upper_bound
             # Multiply the token inputs by the uniform distribution - adding some noise
-            hidden_states *= uniform_distrib
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         # Shape: [num_groups, tokens_per_group, num_experts]
         self._cast_classifier()
...
@@ -169,17 +169,8 @@ class SwitchTransformersTop1Router(nn.Module):
         hidden_states = hidden_states.to(self.dtype)
         if self.jitter_noise > 0:
-            # Get the lower and upper bound of the uniform distribution
-            # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
-            distrib_lower_bound = 1.0 - self.jitter_noise
-            distrib_upper_bound = 1.0 + self.jitter_noise
-            uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
-            uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
-            uniform_distrib = uniform_distrib + distrib_upper_bound
             # Multiply the token inputs by the uniform distribution - adding some noise
-            hidden_states *= uniform_distrib
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         # Shape: [num_groups, tokens_per_group, num_experts]
         self._cast_classifier()
...
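For context (not part of the diff itself), here is a minimal sketch contrasting the two ways of drawing multiplicative jitter noise. The tensor shape and the `jitter_noise` value are made-up for illustration; the point is that the old `torch.rand` scale-and-shift construction and the new single `uniform_` call both sample from the same range `[1 - jitter_noise, 1 + jitter_noise]`.

```python
import torch

jitter_noise = 0.1
# Hypothetical router input of shape [num_groups, tokens_per_group, hidden_dim]
hidden_states = torch.randn(2, 4, 8)

lower, upper = 1.0 - jitter_noise, 1.0 + jitter_noise

# Old approach: scale and shift a sample from torch.rand into the target range.
noise_old = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
noise_old = noise_old * (lower - upper) + upper  # maps [0, 1) into (lower, upper]

# New approach: draw directly from the uniform distribution in one call.
noise_new = torch.empty_like(hidden_states).uniform_(lower, upper)

# Both noise tensors lie in the same range; only the sampling code differs.
for noise in (noise_old, noise_new):
    assert noise.min() >= lower and noise.max() <= upper
```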