Commit d7c12d52 authored by goldenfox2025's avatar goldenfox2025
Browse files

issue/180:添加flp16测试

parent cfaa6af8
...@@ -51,21 +51,16 @@ _TEST_CASES_ = [ ...@@ -51,21 +51,16 @@ _TEST_CASES_ = [
# 特殊形状测试 # 特殊形状测试
((0,), None, None, -1.0, 1.0), # 空张量 ((0,), None, None, -1.0, 1.0), # 空张量
((1, 0), None, None, -1.0, 1.0), # 空维度 ((1, 0), None, None, -1.0, 1.0), # 空维度
# 带stride的测试用例
((5, 10), (10, 1), None, -1.0, 1.0), # 行优先
((5, 10), (1, 5), None, -1.0, 1.0), # 列优先
((5, 10), (10, 1), (10, 1), -1.0, 1.0), # 输入输出都有stride
((5, 10), (1, 5), (1, 5), -1.0, 1.0), # 输入输出都有stride
((5, 10), (10, 1), (1, 5), -1.0, 1.0), # 输入输出有不同的stride
] ]
# 开发机cpu不支持fp16 没有测试
_TENSOR_DTYPES = [torch.float32, torch.float64] _TENSOR_DTYPES = [torch.float16, torch.float32]
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
torch.float32: {"atol": 1e-7, "rtol": 1e-6}, torch.float32: {"atol": 1e-7, "rtol": 1e-6},
torch.float64: {"atol": 1e-10, "rtol": 1e-10},
} }
...@@ -93,8 +88,6 @@ NUM_ITERATIONS = 1000 ...@@ -93,8 +88,6 @@ NUM_ITERATIONS = 1000
class ClipDescriptor(Structure): class ClipDescriptor(Structure):
_fields_ = [("device_type", c_int32), ("device_id", c_int32)] _fields_ = [("device_type", c_int32), ("device_id", c_int32)]
infiniopClipDescriptor_t = POINTER(ClipDescriptor) infiniopClipDescriptor_t = POINTER(ClipDescriptor)
...@@ -104,37 +97,20 @@ def clip(x, min_val, max_val): ...@@ -104,37 +97,20 @@ def clip(x, min_val, max_val):
def create_tensor_with_stride(shape, stride, dtype, device): def create_tensor_with_stride(shape, stride, dtype, device):
"""Create a tensor with specific stride without using view() that might cause errors.""" """Create a tensor with specific stride without using view() that might cause errors."""
x = torch.rand(shape, dtype=dtype, device=device) * 4.0 - 2.0 # Range: [-2, 2] x = torch.rand(shape, dtype=dtype, device=device) * 4.0 - 2.0 # Range: [-2, 2]
if stride is None: if stride is None:
return x return x
if len(shape) == 2 and len(stride) == 2: if len(shape) == 2 and len(stride) == 2:
if stride == (shape[1], 1): if stride == (shape[1], 1):
return x.contiguous() return x.contiguous()
elif stride == (1, shape[0]): elif stride == (1, shape[0]):
return x.transpose(0, 1).contiguous().transpose(0, 1) return x.transpose(0, 1).contiguous().transpose(0, 1)
else: else:
y = torch.zeros(shape, dtype=dtype, device=device) y = torch.zeros(shape, dtype=dtype, device=device)
for i in range(shape[0]): for i in range(shape[0]):
for j in range(shape[1]): for j in range(shape[1]):
y[i, j] = x[i, j] y[i, j] = x[i, j]
return y.contiguous() return y.contiguous()
return x return x
...@@ -154,43 +130,21 @@ def test( ...@@ -154,43 +130,21 @@ def test(
f"Testing Clip on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} " f"Testing Clip on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
f"min_val:{min_val} max_val:{max_val} dtype:{dtype} inplace:{inplace}" f"min_val:{min_val} max_val:{max_val} dtype:{dtype} inplace:{inplace}"
) )
x = create_tensor_with_stride(shape, x_stride, dtype, torch_device) x = create_tensor_with_stride(shape, x_stride, dtype, torch_device)
# Create tensor versions of min_val and max_val with the same shape as x
min_tensor = torch.full(shape, min_val, dtype=dtype, device=torch_device)
max_tensor = torch.full(shape, max_val, dtype=dtype, device=torch_device)
ans = clip(x, min_val, max_val) ans = clip(x, min_val, max_val)
# 确保张量是连续的,然后再重新排列
x = x.contiguous()
min_tensor = min_tensor.contiguous()
max_tensor = max_tensor.contiguous()
x = rearrange_if_needed(x, x_stride) x = rearrange_if_needed(x, x_stride)
min_tensor = rearrange_if_needed(min_tensor, x_stride)
max_tensor = rearrange_if_needed(max_tensor, x_stride)
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
min_tensor_desc = to_tensor(min_tensor, lib)
max_tensor_desc = to_tensor(max_tensor, lib)
if inplace == Inplace.INPLACE_X: if inplace == Inplace.INPLACE_X:
y = x y = x
y_tensor = x_tensor y_tensor = x_tensor
else: else:
y = torch.zeros(shape, dtype=dtype).to(torch_device) y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = y.contiguous() # 确保张量是连续的
y = rearrange_if_needed(y, y_stride) y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
descriptor = infiniopClipDescriptor_t() descriptor = infiniopClipDescriptor_t()
check_error( check_error(
lib.infiniopCreateClipDescriptor( lib.infiniopCreateClipDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor, handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
min_tensor_desc.descriptor, max_tensor_desc.descriptor
) )
) )
...@@ -210,8 +164,8 @@ def test( ...@@ -210,8 +164,8 @@ def test(
workspace_size.value, workspace_size.value,
y_tensor.data, y_tensor.data,
x_tensor.data, x_tensor.data,
min_tensor_desc.data, c_float(min_val),
max_tensor_desc.data, c_float(max_val),
None, None,
) )
) )
...@@ -220,8 +174,6 @@ def test( ...@@ -220,8 +174,6 @@ def test(
# Now we can destroy the tensor descriptors # Now we can destroy the tensor descriptors
x_tensor.destroyDesc(lib) x_tensor.destroyDesc(lib)
min_tensor_desc.destroyDesc(lib)
max_tensor_desc.destroyDesc(lib)
if inplace != Inplace.INPLACE_X: if inplace != Inplace.INPLACE_X:
y_tensor.destroyDesc(lib) y_tensor.destroyDesc(lib)
...@@ -257,8 +209,6 @@ if __name__ == "__main__": ...@@ -257,8 +209,6 @@ if __name__ == "__main__":
POINTER(infiniopClipDescriptor_t), POINTER(infiniopClipDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
] ]
lib.infiniopGetClipWorkspaceSize.restype = c_int32 lib.infiniopGetClipWorkspaceSize.restype = c_int32
...@@ -274,8 +224,8 @@ if __name__ == "__main__": ...@@ -274,8 +224,8 @@ if __name__ == "__main__":
c_uint64, c_uint64,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_float,
c_void_p, c_float,
c_void_p, c_void_p,
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment