from seek import *  # noqa: F403 ('from seek import *' used; unable to detect undefined names)
import unittest


class TestGpr(unittest.TestCase):
    def _assert_gpr_not_view(self, gpr, rtype, name, count, align, repr_s):
        self.assertFalse(gpr.is_view)
        self.assertEqual(gpr.rtype, rtype)
        self.assertEqual(gpr.name, name)
        self.assertEqual(gpr.count, count)
        self.assertEqual(gpr.align, align)
        self.assertIs(gpr.base_gpr, gpr)
        self.assertEqual(gpr.base_offset, 0)
        self.assertEqual(repr(gpr), repr_s)

    def _assert_gpr_is_view(self, gpr, rtype, name, count, align, base_gpr, base_offset, repr_s):
        self.assertTrue(gpr.is_view)
        self.assertEqual(gpr.rtype, rtype)
        self.assertEqual(gpr.name, name)
        self.assertEqual(gpr.count, count)
        self.assertEqual(gpr.align, align)
        self.assertIs(gpr.base_gpr, base_gpr)
        self.assertEqual(gpr.base_offset, base_offset)
        self.assertEqual(repr(gpr), repr_s)

    def test_subscription_and_alias(self):
        v_value = new(name="value", rtype=GprType.V, count=4, align=4)
        self._assert_gpr_not_view(v_value, rtype=GprType.V, name="value", count=4, align=Align(4,0),
                                  repr_s="VGpr4('value', align=4)")

        v_value23 = v_value[2:3]
        self._assert_gpr_is_view(v_value23, rtype=GprType.V, name="value[2:3]",
                                 count=2, align=Align(4,2), base_gpr=v_value, base_offset=2,
                                 repr_s="VGpr2('@value[2:3]', align=(4,2))")

        self._assert_gpr_is_view(v_value23[0], rtype=GprType.V, name="value[2:3][0]",
                                 count=1, align=Align(4,2), base_gpr=v_value, base_offset=2,
                                 repr_s="VGpr1('@value[2:3][0]', ref=VGpr4('value', align=4)[2])")

        self._assert_gpr_is_view(v_value23[0:0], rtype=GprType.V, name="value[2:3][0:0]",
                                 count=1, align=Align(4,2), base_gpr=v_value, base_offset=2,
                                 repr_s="VGpr1('@value[2:3][0:0]', ref=VGpr4('value', align=4)[2])")

        self._assert_gpr_is_view(v_value23[0:1][1][0], rtype=GprType.V, name="value[2:3][0:1][1][0]",
                                 count=1, align=Align(4,3), base_gpr=v_value, base_offset=3,
                                 repr_s="VGpr1('@value[2:3][0:1][1][0]', ref=VGpr4('value', align=4)[3])")

        v_aliased = v_value23[0:1].alias('aliased')
        self._assert_gpr_is_view(v_aliased, rtype=GprType.V, name="aliased",
                                 count=2, align=Align(4,2), base_gpr=v_value, base_offset=2,
                                 repr_s="VGpr2('@aliased', ref=VGpr4('value', align=4)[2:3])")

        self._assert_gpr_is_view(v_aliased[1], rtype=GprType.V, name="aliased[1]",
                                 count=1, align=Align(4,3), base_gpr=v_value, base_offset=3,
                                 repr_s="VGpr1('@aliased[1]', ref=VGpr4('value', align=4)[3])")

        v_inferred = v_value[1:2][1].alias()  # aliased name is inferred from assignment target 'v_inferred'
        self._assert_gpr_is_view(v_inferred, rtype=GprType.V, name="inferred",
                                 count=1, align=Align(4,2), base_gpr=v_value, base_offset=2,
                                 repr_s="VGpr1('@inferred', ref=VGpr4('value', align=4)[2])")

    def test_new_register(self):
        v_data = new()
        self._assert_gpr_not_view(v_data, rtype=GprType.V, name="data",
                                  count=1, align=Align(1,0), repr_s="VGpr1('data')")

        s_data2 = new(count=2)
        self._assert_gpr_not_view(s_data2, rtype=GprType.S, name="data2",
                                  count=2, align=Align(1,0), repr_s="SGpr2('data2')")

        s_list = new[3](count=2)
        for i in range(0, 3):
            self._assert_gpr_not_view(s_list[i], rtype=GprType.S, name=f"list_{i}",
                                      count=2, align=Align(1,0), repr_s=f"SGpr2('list_{i}')")

        v_matrix = new[3,4]()  # default count is 1
        for i in range(0, 3):
            for j in range(0, 4):
                self._assert_gpr_not_view(v_matrix[i,j], rtype=GprType.V, name=f"matrix_{i}_{j}",
                                          count=1, align=Align(1,0), repr_s=f"VGpr1('matrix_{i}_{j}')")

        v_tensor = new[3,4,5]()  # default count is 1
        for i in range(0, 3):
            for j in range(0, 4):
                for k in range(0, 5):
                    self._assert_gpr_not_view(v_tensor[i,j,k], rtype=GprType.V, name=f"tensor_{i}_{j}_{k}",
                                              count=1, align=Align(1,0), repr_s=f"VGpr1('tensor_{i}_{j}_{k}')")

    # noinspection PyStatementEffect
    def test_invalid_subscription(self):
        v_data = new()

        # Should not throw
        v_data[0]
        v_data[0][0]
        v_data[0:0][0]

        # Should throw
        self.assertRaises(SeekException, lambda: v_data[1])
        self.assertRaises(SeekException, lambda: v_data[1:1])
        self.assertRaises(SeekException, lambda: v_data[0:1])
        self.assertRaises(SeekException, lambda: v_data[-1])
        self.assertRaises(SeekException, lambda: v_data[-1:0])

        v_data2 = new(count=2)

        # Should not throw
        v_data2[0]
        v_data2[1]
        v_data2[0:1]
        v_data2[1:1]
        v_data2[0:1][1]

        # Should throw
        self.assertRaises(SeekException, lambda: v_data2[2])
        self.assertRaises(SeekException, lambda: v_data2[1:2])
        self.assertRaises(SeekException, lambda: v_data2[0:2])
        self.assertRaises(SeekException, lambda: v_data2[-1])
        self.assertRaises(SeekException, lambda: v_data2[-1:0])
        self.assertRaises(SeekException, lambda: v_data2[0:1][2])

    # noinspection PyStatementEffect
    def test_neg_abs_sext(self):
        v_value = new(count=2)

        # Should not throw
        -v_value
        abs(v_value)
        -abs(v_value[0])
        sext(v_value[1])

        # Should throw
        self.assertRaises(SeekException, lambda: abs(-v_value))
        self.assertRaises(SeekException, lambda: abs(abs(v_value)))
        self.assertRaises(SeekException, lambda: -(-v_value))
        self.assertRaises(SeekException, lambda: sext(-v_value))
        self.assertRaises(SeekException, lambda: sext(abs(v_value)))
        self.assertRaises(SeekException, lambda: abs(sext(v_value)))
        self.assertRaises(SeekException, lambda: sext(sext(v_value)))
        self.assertRaises(SeekException, lambda: -(sext(v_value)))

        # repr
        self.assertEqual(repr(-v_value), "-VGpr2('@value')")
        self.assertEqual(repr(abs(v_value)), "|VGpr2|('@value')")
        self.assertEqual(repr(-abs(v_value)), "-|VGpr2|('@value')")
        self.assertEqual(repr(sext(v_value)), "VGpr2.sext('@value')")

        self.assertEqual(repr(-v_value[1]), "-VGpr1('@value[1]')")
        self.assertEqual(repr(abs(v_value[1]).alias('xxx')), "|VGpr1|('@xxx', ref=VGpr2('value')[1])")
        self.assertEqual(repr(-abs(v_value[0])), "-|VGpr1|('@value[0]')")
        self.assertEqual(repr(sext(v_value[1:1][0])), "VGpr1.sext('@value[1:1][0]', ref=VGpr2('value')[1])")

    def test_compare(self):
        self.assertTrue(vcc[0] == vcc_lo)
        self.assertTrue(vcc[1] == vcc_hi)
        self.assertTrue({vcc[1], vcc_lo, vcc[0:1]} == {vcc, vcc[0], vcc_hi})  # Gpr is hashable
        self.assertTrue(vcc[0] == -vcc[0])  # neg() is not considered in compare
        self.assertTrue(vcc_lo == abs(vcc[0]))  # abs() is not considered in compare
        self.assertTrue(vcc[0] == -abs(vcc_lo))  # neg(), abs() are not considered in compare
        self.assertTrue(vcc[0] == sext(vcc_lo))  # sext() is not considered in compare

        # Gprs are ordered
        v_r1 = new()
        s_r2 = new()
        self.assertTrue(v_r1 != s_r2)
        self.assertFalse(v_r1 == s_r2)
        self.assertTrue(v_r1 < s_r2)
        self.assertTrue(v_r1 <= s_r2)

        # Compare with None
        objNone = None
        # noinspection DuplicatedCode
        self.assertTrue(threadIdx.x != objNone)
        self.assertFalse(threadIdx.x == objNone)
        self.assertRaises(TypeError, lambda: threadIdx.x < objNone)
        self.assertRaises(TypeError, lambda: threadIdx.x <= objNone)
        self.assertRaises(TypeError, lambda: threadIdx.x > objNone)
        self.assertRaises(TypeError, lambda: threadIdx.x >= objNone)

        # Compare with types other than Gpr
        objInt = 42
        # noinspection DuplicatedCode
        self.assertTrue(threadIdx.x != objInt)
        self.assertFalse(threadIdx.x == objInt)
        self.assertRaises(TypeError, lambda: threadIdx.x < objInt)
        self.assertRaises(TypeError, lambda: threadIdx.x <= objInt)
        self.assertRaises(TypeError, lambda: threadIdx.x > objInt)
        self.assertRaises(TypeError, lambda: threadIdx.x >= objInt)


class TestGprSet(unittest.TestCase):
    def test_is_superset(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)

        gprset = GprSet(v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3], execz, exec, v_data[3], xnack_mask_lo)
        gprset._sanity_check()

        self.assertTrue(gprset.is_superset(vcc_lo))
        self.assertTrue(gprset.is_superset(vcc_hi))
        self.assertTrue(gprset.is_superset(vcc))
        self.assertTrue(gprset.is_superset(xnack_mask_lo))
        self.assertTrue(gprset.is_superset(s_data[2]))
        self.assertTrue(gprset.is_superset(s_data[3]))
        self.assertTrue(gprset.is_superset(s_data[2:3]))

        self.assertFalse(gprset.is_superset(s_data))
        self.assertFalse(gprset.is_superset(v_data))
        self.assertFalse(gprset.is_superset(v_data[2:3]))
        self.assertFalse(gprset.is_superset(xnack_mask_hi))
        self.assertFalse(gprset.is_superset(xnack_mask))

    def test_is_intersected(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)

        gprset = GprSet(v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3], execz, exec, v_data[3], xnack_mask_lo)
        gprset._sanity_check()

        self.assertTrue(gprset.is_intersected(vcc_lo))
        self.assertTrue(gprset.is_intersected(vcc_hi))
        self.assertTrue(gprset.is_intersected(vcc))
        self.assertTrue(gprset.is_intersected(exec_hi))
        self.assertTrue(gprset.is_intersected(xnack_mask))
        self.assertTrue(gprset.is_intersected(s_data[2]))
        self.assertTrue(gprset.is_intersected(s_data[3]))
        self.assertTrue(gprset.is_intersected(s_data))
        self.assertTrue(gprset.is_intersected(v_data[0]))
        self.assertTrue(gprset.is_intersected(v_data))

        self.assertFalse(gprset.is_intersected(threadIdx.y))
        self.assertFalse(gprset.is_intersected(blockIdx.z))
        self.assertFalse(gprset.is_intersected(vccz))
        self.assertFalse(gprset.is_intersected(lds_direct))

    def test_set_operation(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)

        gprset = GprSet(v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3], execz, exec, v_data[3])
        gprset._sanity_check()
        self.assertListEqual(gprset.base_gprs, [vcc, exec, execz, threadIdx.x, v_data, s_data])
        self.assertListEqual(gprset.expanded_gprs,
                             [vcc_lo, vcc_hi, exec_lo, exec_hi, execz, threadIdx.x,
                              v_data[0], v_data[3], s_data[2], s_data[3]])
        self.assertTrue(gprset)

        tmp = gprset.union(v_data, s_kernarg)
        tmp._sanity_check()
        self.assertListEqual(tmp.base_gprs, [vcc, exec, execz, s_kernarg, threadIdx.x, v_data, s_data])
        self.assertListEqual(tmp.expanded_gprs,
                             [vcc_lo, vcc_hi, exec_lo, exec_hi, execz, s_kernarg[0], s_kernarg[1], threadIdx.x,
                              v_data[0], v_data[1], v_data[2], v_data[3], s_data[2], s_data[3]])

        tmp = gprset.difference(exec_hi, s_data[1:2], exec_lo, vcc_hi, blockIdx.x)
        tmp._sanity_check()
        self.assertListEqual(tmp.base_gprs, [vcc, execz, threadIdx.x, v_data, s_data])
        self.assertListEqual(tmp.expanded_gprs,
                             [vcc_lo, execz, threadIdx.x, v_data[0], v_data[3], s_data[3]])

        tmp = gprset.intersect(GprSet(exec_hi, s_data[1:2], exec_lo, vcc_hi, blockIdx.x))
        tmp._sanity_check()
        self.assertListEqual(tmp.base_gprs, [vcc, exec, s_data])
        self.assertListEqual(tmp.expanded_gprs,
                             [vcc_hi, exec_lo, exec_hi, s_data[2]])


if __name__ == '__main__':
    # Run as a script: delegate to unittest's command-line runner, which collects
    # the TestCase classes defined in this module ('__main__' is also the default).
    unittest.main(module='__main__')
