Unverified Commit 88256082 authored by Zhang Jian's avatar Zhang Jian Committed by GitHub
Browse files

[Bugfix][CI] Fix wrong residual shape in TestFusedAddRMSNorm.example_inputs...


[Bugfix][CI] Fix wrong residual shape in TestFusedAddRMSNorm.example_inputs that causes flaky test (#40629)
Signed-off-by: default avatarZhang Jian <jianmusings@gmail.com>
parent 095d2f87
......@@ -117,9 +117,9 @@ class TestFusedAddRMSNorm(torch.nn.Module):
else:
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
def example_inputs(self, batch_size=8, seq_len=16):
hidden_states = torch.randn((batch_size * seq_len, self.hidden_size))
residual = torch.randn((batch_size * seq_len, self.intermediate_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment