Commit 89de2153 authored by Rick Ho
Browse files

zero gate update

parent 94eca783
......@@ -10,15 +10,18 @@ import torch.nn.functional as F
class ZeroGate(nn.Module):
    r'''
    Trivial gate that assigns every input row to expert 0.

    No parameters are learned: the expert indices are all zero and each of
    the `top_k` slots of a row receives the same weight `1 / top_k`.
    '''

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        r'''
        `d_model`, `num_expert` and `world_size` are accepted but not used by
        this gate; only `top_k` is stored.
        '''
        super().__init__()
        self.top_k = top_k

    def forward(self, inp):
        r'''
        Produce constant routing for `inp`.

        Only `inp.shape[0]` and `inp.device` are read from the input.

        Returns a tuple `(idx, score)`:
        * `idx`: int64 tensor of `inp.shape[0] * top_k` zeros.
        * `score`: tensor of shape `(inp.shape[0], 1, top_k)` whose entries
          are all `1 / top_k`.
        '''
        # One slot per (row, k) pair; every slot points at expert 0.
        num_slots = inp.shape[0] * self.top_k
        idx = torch.zeros(num_slots, dtype=torch.int64, device=inp.device)
        # Split the weight evenly so the top_k scores of a row sum to 1.
        score = torch.ones(num_slots, device=inp.device) / self.top_k
        return idx, score.reshape(-1, 1, self.top_k)
class NaiveGate(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment