"""Unit tests for register (Gpr) views, aliases, modifiers and GprSet operations."""
from seek import *  # noqa: F403 ('from seek import *' used; unable to detect undefined names)
import unittest


class TestGpr(unittest.TestCase):
    """Checks on single registers: creation, subscription, aliasing, modifiers, comparison.

    NOTE: several tests rely on the register name/type being derived from the
    assignment target (e.g. ``v_data = new()`` is expected to yield a V register
    named ``data``), so those local variable names must not be changed.
    """

    def _assert_gpr_fields(self, gpr, rtype, name, count, align, base_gpr, base_offset, repr_s):
        """Check every attribute that view and non-view registers have in common."""
        self.assertEqual(gpr.rtype, rtype)
        self.assertEqual(gpr.name, name)
        self.assertEqual(gpr.count, count)
        self.assertEqual(gpr.align, align)
        self.assertIs(gpr.base_gpr, base_gpr)
        self.assertEqual(gpr.base_offset, base_offset)
        self.assertEqual(repr(gpr), repr_s)

    def _assert_gpr_not_view(self, gpr, rtype, name, count, align, repr_s):
        """Assert that ``gpr`` is a standalone register: its own base, at offset 0."""
        self.assertFalse(gpr.is_view)
        self._assert_gpr_fields(
            gpr, rtype, name, count, align,
            base_gpr=gpr, base_offset=0, repr_s=repr_s,
        )

    def _assert_gpr_is_view(self, gpr, rtype, name, count, align, base_gpr, base_offset, repr_s):
        """Assert that ``gpr`` is a view onto ``base_gpr`` starting at ``base_offset``."""
        self.assertTrue(gpr.is_view)
        self._assert_gpr_fields(
            gpr, rtype, name, count, align,
            base_gpr=base_gpr, base_offset=base_offset, repr_s=repr_s,
        )

    def test_subscription_and_alias(self):
        # Subscripts and aliases produce views; slice bounds are inclusive
        # (``[2:3]`` is expected to span two registers).
        v_value = new(name="value", rtype=GprType.V, count=4, align=4)
        self._assert_gpr_not_view(
            v_value,
            rtype=GprType.V,
            name="value",
            count=4,
            align=Align(4, 0),
            repr_s="VGpr4('value', align=4)",
        )

        # Every view checked below is a V register backed by ``v_value``.
        shared = {"rtype": GprType.V, "base_gpr": v_value}

        v_value23 = v_value[2:3]
        self._assert_gpr_is_view(
            v_value23,
            name="value[2:3]",
            count=2,
            align=Align(4, 2),
            base_offset=2,
            repr_s="VGpr2('@value[2:3]', align=(4,2))",
            **shared,
        )
        self._assert_gpr_is_view(
            v_value23[0],
            name="value[2:3][0]",
            count=1,
            align=Align(4, 2),
            base_offset=2,
            repr_s="VGpr1('@value[2:3][0]', ref=VGpr4('value', align=4)[2])",
            **shared,
        )
        self._assert_gpr_is_view(
            v_value23[0:0],
            name="value[2:3][0:0]",
            count=1,
            align=Align(4, 2),
            base_offset=2,
            repr_s="VGpr1('@value[2:3][0:0]', ref=VGpr4('value', align=4)[2])",
            **shared,
        )
        self._assert_gpr_is_view(
            v_value23[0:1][1][0],
            name="value[2:3][0:1][1][0]",
            count=1,
            align=Align(4, 3),
            base_offset=3,
            repr_s="VGpr1('@value[2:3][0:1][1][0]', ref=VGpr4('value', align=4)[3])",
            **shared,
        )

        v_aliased = v_value23[0:1].alias('aliased')
        self._assert_gpr_is_view(
            v_aliased,
            name="aliased",
            count=2,
            align=Align(4, 2),
            base_offset=2,
            repr_s="VGpr2('@aliased', ref=VGpr4('value', align=4)[2:3])",
            **shared,
        )
        self._assert_gpr_is_view(
            v_aliased[1],
            name="aliased[1]",
            count=1,
            align=Align(4, 3),
            base_offset=3,
            repr_s="VGpr1('@aliased[1]', ref=VGpr4('value', align=4)[3])",
            **shared,
        )

        v_inferred = v_value[1:2][1].alias()  # aliased name is inferred from assignment target 'v_inferred'
        self._assert_gpr_is_view(
            v_inferred,
            name="inferred",
            count=1,
            align=Align(4, 2),
            base_offset=2,
            repr_s="VGpr1('@inferred', ref=VGpr4('value', align=4)[2])",
            **shared,
        )

    def test_new_register(self):
        # The ``new`` lines are kept verbatim: the expected type and name
        # below follow from the assignment targets.
        v_data = new()
        self._assert_gpr_not_view(
            v_data,
            rtype=GprType.V,
            name="data",
            count=1,
            align=Align(1, 0),
            repr_s="VGpr1('data')",
        )

        s_data2 = new(count=2)
        self._assert_gpr_not_view(
            s_data2,
            rtype=GprType.S,
            name="data2",
            count=2,
            align=Align(1, 0),
            repr_s="SGpr2('data2')",
        )

        s_list = new[3](count=2)
        for idx in range(3):
            label = f"list_{idx}"
            self._assert_gpr_not_view(
                s_list[idx],
                rtype=GprType.S,
                name=label,
                count=2,
                align=Align(1, 0),
                repr_s=f"SGpr2('{label}')",
            )

        v_matrix = new[3,4]()  # default count is 1
        for row in range(3):
            for col in range(4):
                label = f"matrix_{row}_{col}"
                self._assert_gpr_not_view(
                    v_matrix[row, col],
                    rtype=GprType.V,
                    name=label,
                    count=1,
                    align=Align(1, 0),
                    repr_s=f"VGpr1('{label}')",
                )

        v_tensor = new[3,4,5]()  # default count is 1
        for plane in range(3):
            for row in range(4):
                for col in range(5):
                    label = f"tensor_{plane}_{row}_{col}"
                    self._assert_gpr_not_view(
                        v_tensor[plane, row, col],
                        rtype=GprType.V,
                        name=label,
                        count=1,
                        align=Align(1, 0),
                        repr_s=f"VGpr1('{label}')",
                    )

    # noinspection PyStatementEffect
    def test_invalid_subscription(self):
        # Out-of-range or negative subscripts must raise SeekException.
        v_data = new()

        # Should not throw
        v_data[0]
        v_data[0][0]
        v_data[0:0][0]

        # Should throw
        for bad_subscript in (
            lambda: v_data[1],
            lambda: v_data[1:1],
            lambda: v_data[0:1],
            lambda: v_data[-1],
            lambda: v_data[-1:0],
        ):
            self.assertRaises(SeekException, bad_subscript)

        v_data2 = new(count=2)

        # Should not throw
        v_data2[0]
        v_data2[1]
        v_data2[0:1]
        v_data2[1:1]
        v_data2[0:1][1]

        # Should throw
        for bad_subscript in (
            lambda: v_data2[2],
            lambda: v_data2[1:2],
            lambda: v_data2[0:2],
            lambda: v_data2[-1],
            lambda: v_data2[-1:0],
            lambda: v_data2[0:1][2],
        ):
            self.assertRaises(SeekException, bad_subscript)

    # noinspection PyStatementEffect
    def test_neg_abs_sext(self):
        # Only ``-x``, ``abs(x)``, ``-abs(x)`` and ``sext(x)`` are accepted;
        # any other stacking of modifiers must raise SeekException.
        v_value = new(count=2)

        # Should not throw
        -v_value
        abs(v_value)
        -abs(v_value[0])
        sext(v_value[1])

        # Should throw
        for bad_modifier in (
            lambda: abs(-v_value),
            lambda: abs(abs(v_value)),
            lambda: -(-v_value),
            lambda: sext(-v_value),
            lambda: sext(abs(v_value)),
            lambda: abs(sext(v_value)),
            lambda: sext(sext(v_value)),
            lambda: -(sext(v_value)),
        ):
            self.assertRaises(SeekException, bad_modifier)

        # repr
        expected_reprs = (
            (-v_value, "-VGpr2('@value')"),
            (abs(v_value), "|VGpr2|('@value')"),
            (-abs(v_value), "-|VGpr2|('@value')"),
            (sext(v_value), "VGpr2.sext('@value')"),
            (-v_value[1], "-VGpr1('@value[1]')"),
            (abs(v_value[1]).alias('xxx'), "|VGpr1|('@xxx', ref=VGpr2('value')[1])"),
            (-abs(v_value[0]), "-|VGpr1|('@value[0]')"),
            (sext(v_value[1:1][0]), "VGpr1.sext('@value[1:1][0]', ref=VGpr2('value')[1])"),
        )
        for modified, text in expected_reprs:
            self.assertEqual(repr(modified), text)

    def test_compare(self):
        # The comparison operators are spelled out (rather than using
        # assertEqual & co.) because the operators themselves are under test.
        self.assertTrue(vcc[0] == vcc_lo)
        self.assertTrue(vcc[1] == vcc_hi)
        self.assertTrue({vcc[1], vcc_lo, vcc[0:1]} == {vcc, vcc[0], vcc_hi})  # Gpr is hashable
        self.assertTrue(vcc[0] == -vcc[0])  # neg() is not considered in compare
        self.assertTrue(vcc_lo == abs(vcc[0]))  # abs() is not considered in compare
        self.assertTrue(vcc[0] == -abs(vcc_lo))  # neg(), abs() are not considered in compare
        self.assertTrue(vcc[0] == sext(vcc_lo))  # sext() is not considered in compare

        # Gprs are ordered
        v_r1 = new()
        s_r2 = new()
        self.assertTrue(v_r1 != s_r2)
        self.assertFalse(v_r1 == s_r2)
        self.assertTrue(v_r1 < s_r2)
        self.assertTrue(v_r1 <= s_r2)

        # Compare with None, then with a type other than Gpr: (in)equality
        # works, ordering raises TypeError.
        for other in (None, 42):
            self.assertTrue(threadIdx.x != other)
            self.assertFalse(threadIdx.x == other)
            for ordering in (
                lambda: threadIdx.x < other,
                lambda: threadIdx.x <= other,
                lambda: threadIdx.x > other,
                lambda: threadIdx.x >= other,
            ):
                self.assertRaises(TypeError, ordering)


class TestGprSet(unittest.TestCase):
    """Checks on GprSet: superset/intersection queries and set algebra."""

    def _assert_members(self, gprset, base, expanded):
        """Compare the ``base_gprs`` and ``expanded_gprs`` lists of ``gprset``."""
        self.assertListEqual(gprset.base_gprs, base)
        self.assertListEqual(gprset.expanded_gprs, expanded)

    def test_is_superset(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)
        gprset = GprSet(
            v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3],
            execz, exec, v_data[3], xnack_mask_lo,
        )
        gprset._sanity_check()

        for covered in (vcc_lo, vcc_hi, vcc, xnack_mask_lo, s_data[2], s_data[3], s_data[2:3]):
            self.assertTrue(gprset.is_superset(covered))

        for not_covered in (s_data, v_data, v_data[2:3], xnack_mask_hi, xnack_mask):
            self.assertFalse(gprset.is_superset(not_covered))

    def test_is_intersected(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)
        gprset = GprSet(
            v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3],
            execz, exec, v_data[3], xnack_mask_lo,
        )
        gprset._sanity_check()

        for overlapping in (
            vcc_lo, vcc_hi, vcc, exec_hi, xnack_mask,
            s_data[2], s_data[3], s_data, v_data[0], v_data,
        ):
            self.assertTrue(gprset.is_intersected(overlapping))

        for disjoint in (threadIdx.y, blockIdx.z, vccz, lds_direct):
            self.assertFalse(gprset.is_intersected(disjoint))

    def test_set_operation(self):
        v_data = new(count=4)
        s_data = new(count=4, align=2)
        gprset = GprSet(v_data[0], vcc_hi, threadIdx.x, vcc, s_data[2:3], execz, exec, v_data[3])
        gprset._sanity_check()
        self._assert_members(
            gprset,
            base=[vcc, exec, execz, threadIdx.x, v_data, s_data],
            expanded=[
                vcc_lo, vcc_hi, exec_lo, exec_hi, execz, threadIdx.x,
                v_data[0], v_data[3], s_data[2], s_data[3],
            ],
        )
        self.assertTrue(gprset)

        united = gprset.union(v_data, s_kernarg)
        united._sanity_check()
        self._assert_members(
            united,
            base=[vcc, exec, execz, s_kernarg, threadIdx.x, v_data, s_data],
            expanded=[
                vcc_lo, vcc_hi, exec_lo, exec_hi, execz,
                s_kernarg[0], s_kernarg[1], threadIdx.x,
                v_data[0], v_data[1], v_data[2], v_data[3],
                s_data[2], s_data[3],
            ],
        )

        remaining = gprset.difference(exec_hi, s_data[1:2], exec_lo, vcc_hi, blockIdx.x)
        remaining._sanity_check()
        self._assert_members(
            remaining,
            base=[vcc, execz, threadIdx.x, v_data, s_data],
            expanded=[vcc_lo, execz, threadIdx.x, v_data[0], v_data[3], s_data[3]],
        )

        common = gprset.intersect(GprSet(exec_hi, s_data[1:2], exec_lo, vcc_hi, blockIdx.x))
        common._sanity_check()
        self._assert_members(
            common,
            base=[vcc, exec, s_data],
            expanded=[vcc_hi, exec_lo, exec_hi, s_data[2]],
        )


if __name__ == '__main__':
    unittest.main()