"vscode:/vscode.git/clone" did not exist on "70022ffc002dabc59b936d3fc001b94b81ba08db"
Unverified commit 064cac7b, authored by Nikhil Gupta, committed by GitHub
Browse files

[fix]: remove data type hardcoding from gptoss model implementation (#23807)


Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
parent e19bce40
......@@ -76,7 +76,6 @@ class OAIAttention(nn.Module):
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=torch.bfloat16,
requires_grad=False))
self.q_size = self.num_attention_heads * self.head_dim // tp_size
......@@ -145,8 +144,7 @@ class MLPBlock(torch.nn.Module):
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size,
config.num_local_experts,
dtype=torch.bfloat16)
config.num_local_experts)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment