Commit da77b216 authored by baberabb's avatar baberabb
Browse files

added indices equality check

parent 72bb97ca
...@@ -91,6 +91,11 @@ class Test_HFLM: ...@@ -91,6 +91,11 @@ class Test_HFLM:
res = self.LM.loglikelihood(self.MULTIPLE_CH) res = self.LM.loglikelihood(self.MULTIPLE_CH)
_RES, _res = [r[0] for r in self.MULTIPLE_CH_RES], [r[0] for r in res] _RES, _res = [r[0] for r in self.MULTIPLE_CH_RES], [r[0] for r in res]
assert np.allclose(_res, _RES, atol=1e-2) assert np.allclose(_res, _RES, atol=1e-2)
# check indices for Multiple Choice
argmax_RES, argmax_res = np.argmax(
np.array(_RES).reshape(-1, 4), axis=1
), np.argmax(np.array(_res).reshape(-1, 4), axis=1)
assert (argmax_RES == argmax_res).all()
def test_greedy_until(self) -> None: def test_greedy_until(self) -> None:
res = self.LM.greedy_until(self.GREEDY_UNTIL) res = self.LM.greedy_until(self.GREEDY_UNTIL)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment