Commit 89de2153 authored by Rick Ho's avatar Rick Ho
Browse files

zero gate update

parent 94eca783
...@@ -10,15 +10,18 @@ import torch.nn.functional as F ...@@ -10,15 +10,18 @@ import torch.nn.functional as F
class ZeroGate(nn.Module):
    r'''A degenerate gate that routes every token to expert 0.

    Useful as a baseline / debugging gate: it performs no learned routing.
    ``d_model``, ``num_expert`` and ``world_size`` are accepted only to keep
    the constructor signature interchangeable with other gates (e.g.
    ``NaiveGate``); they are intentionally unused.
    '''
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        # Number of experts each token is (nominally) dispatched to.
        self.top_k = top_k

    def forward(self, inp):
        r'''Return constant routing decisions for a batch of tokens.

        Args:
            inp: tensor whose first dimension is the number of tokens;
                only ``inp.shape[0]`` and ``inp.device`` are read.

        Returns:
            idx: int64 tensor of shape ``(B * top_k,)``, all zeros —
                every slot is assigned to expert 0.
            score: tensor of shape ``(B, 1, top_k)`` with uniform weight
                ``1 / top_k`` so the per-token weights sum to 1.
        '''
        batch = inp.shape[0]
        idx = torch.zeros(batch * self.top_k,
                          dtype=torch.int64, device=inp.device)
        score = torch.ones(batch * self.top_k,
                           device=inp.device) / self.top_k
        return idx, score.reshape(-1, 1, self.top_k)
class NaiveGate(nn.Module): class NaiveGate(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment