Unverified commit fb5e36d2 authored by thatPepe, committed by GitHub
Browse files

Merge pull request #793 from InfiniTensor/issue/792

issue/792 - add Moore Threads (musa) device synchronization to tests
parents 1f871ea9 0017fa0b
......@@ -13,6 +13,8 @@ def synchronize_device(torch_device):
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
......
......@@ -71,17 +71,32 @@ class TestTensor(CTensor):
torch_shape.append(shape[i])
if mode == "random":
# For integer types, use randint instead of rand
if dt in [InfiniDtype.I8, InfiniDtype.I16, InfiniDtype.I32, InfiniDtype.I64,
InfiniDtype.U8, InfiniDtype.U16, InfiniDtype.U32, InfiniDtype.U64,
InfiniDtype.BYTE, InfiniDtype.BOOL]:
if dt in [
InfiniDtype.I8,
InfiniDtype.I16,
InfiniDtype.I32,
InfiniDtype.I64,
InfiniDtype.U8,
InfiniDtype.U16,
InfiniDtype.U32,
InfiniDtype.U64,
InfiniDtype.BYTE,
InfiniDtype.BOOL,
]:
randint_low = -2000000000 if randint_low is None else randint_low
randint_high = 2000000000 if randint_high is None else randint_high
self._torch_tensor = torch.randint(
randint_low, randint_high, torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
randint_low,
randint_high,
torch_shape,
dtype=to_torch_dtype(dt),
device=torch_device_map[device],
)
else:
self._torch_tensor = torch.rand(
torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
torch_shape,
dtype=to_torch_dtype(dt),
device=torch_device_map[device],
)
elif mode == "zeros":
self._torch_tensor = torch.zeros(
......@@ -431,6 +446,8 @@ def synchronize_device(torch_device):
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
......@@ -463,7 +480,14 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def filter_tensor_dtypes_by_device(device, tensor_dtypes):
if device in (InfiniDeviceEnum.CPU, InfiniDeviceEnum.NVIDIA, InfiniDeviceEnum.METAX, InfiniDeviceEnum.ASCEND, InfiniDeviceEnum.ILUVATAR, InfiniDeviceEnum.CAMBRICON):
if device in (
InfiniDeviceEnum.CPU,
InfiniDeviceEnum.NVIDIA,
InfiniDeviceEnum.METAX,
InfiniDeviceEnum.ASCEND,
InfiniDeviceEnum.ILUVATAR,
InfiniDeviceEnum.CAMBRICON,
):
return tensor_dtypes
else:
# Filter out torch.bfloat16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment