Commit ffb86c66 authored by liam's avatar liam
Browse files

fix experts torch

parent de082f14
...@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase): ...@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
self.up = torch.cat(self.up, dim=0) self.up = torch.stack(self.up, dim=0)
self.gate = torch.cat(self.gate, dim=0) self.gate = torch.stack(self.gate, dim=0)
self.down = torch.cat(self.down, dim=0) self.down = torch.stack(self.down, dim=0)
return return
def unload(self): def unload(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment