Commit 6e7fe25f authored by PanZezhong's avatar PanZezhong
Browse files

issue/304 fix clip

parent 8366330c
......@@ -13,6 +13,7 @@ def run_tests(args):
failed = []
for test in [
"add.py",
"clip.py",
"gemm.py",
"random_sample.py",
"rms_norm.py",
......@@ -22,7 +23,7 @@ def run_tests(args):
"attention.py",
"causal_softmax.py",
"rearrange.py",
"mul.py"
"mul.py",
]:
result = subprocess.run(
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True
......
......@@ -19,9 +19,10 @@ public:
#else
return {std::clamp(x.x, min_val.x, max_val.x), std::clamp(x.y, min_val.y, max_val.y)};
#endif
}
} else {
return std::clamp(x, min_val, max_val);
}
}
} ClipOp;
} // namespace op::clip::cuda
......
......@@ -48,10 +48,6 @@ _TEST_CASES_ = [
((10,), None, None, -1000.0, 1000.0), # 大范围
((10,), None, None, -0.001, 0.001), # 小范围
((10,), None, None, 0.0, 0.0), # min=max
# 特殊形状测试
((0,), None, None, -1.0, 1.0), # 空张量
((1, 0), None, None, -1.0, 1.0), # 空维度
]
......@@ -88,6 +84,8 @@ NUM_ITERATIONS = 1000
class ClipDescriptor(Structure):
_fields_ = [("device_type", c_int32), ("device_id", c_int32)]
infiniopClipDescriptor_t = POINTER(ClipDescriptor)
......@@ -95,25 +93,6 @@ def clip(x, min_val, max_val):
return torch.clamp(x, min_val, max_val)
def create_tensor_with_stride(shape, stride, dtype, device):
    """Build a random tensor in [-2, 2) that honors *stride* for 2-D layouts.

    Avoids ``view()`` (which can raise on non-contiguous inputs) by producing
    the requested layout through ``contiguous()``/``transpose()`` or, for
    arbitrary strides, an explicit element copy.

    Args:
        shape: Target tensor shape.
        stride: Desired strides, or ``None`` for the default layout.
        dtype: Torch dtype of the result.
        device: Device the tensor is created on.

    Returns:
        A tensor of *shape* with values drawn uniformly from [-2, 2).
    """
    data = torch.rand(shape, dtype=dtype, device=device) * 4.0 - 2.0  # Range: [-2, 2]
    if stride is None:
        return data
    # Stride handling is only implemented for the 2-D case; anything else
    # falls through to the default layout.
    if len(shape) != 2 or len(stride) != 2:
        return data
    rows, cols = shape
    if stride == (cols, 1):
        # Row-major request: a contiguous copy already has this stride.
        return data.contiguous()
    if stride == (1, rows):
        # Column-major request: materialize the transpose, then view it back.
        return data.transpose(0, 1).contiguous().transpose(0, 1)
    # Arbitrary stride: fall back to an element-wise copy into a fresh buffer.
    buffer = torch.zeros(shape, dtype=dtype, device=device)
    for r in range(rows):
        for c in range(cols):
            buffer[r, c] = data[r, c]
    return buffer.contiguous()
def test(
lib,
handle,
......@@ -125,12 +104,13 @@ def test(
max_val=1.0,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32,
sync=None,
):
print(
f"Testing Clip on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
f"min_val:{min_val} max_val:{max_val} dtype:{dtype} inplace:{inplace}"
)
x = create_tensor_with_stride(shape, x_stride, dtype, torch_device)
x = torch.rand(shape, dtype=dtype).to(torch_device)
ans = clip(x, min_val, max_val)
x = rearrange_if_needed(x, x_stride)
x_tensor = to_tensor(x, lib)
......@@ -141,18 +121,34 @@ def test(
y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib)
if sync is not None:
sync()
descriptor = infiniopClipDescriptor_t()
min_, max_ = torch.tensor([min_val], dtype=dtype).to(torch_device), torch.tensor(
[max_val], dtype=dtype
).to(torch_device)
min_tensor = to_tensor(
min_, lib, force_shape=shape, force_strides=[0 for _ in shape]
)
max_tensor = to_tensor(
max_, lib, force_shape=shape, force_strides=[0 for _ in shape]
)
check_error(
lib.infiniopCreateClipDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
min_tensor.descriptor,
max_tensor.descriptor,
)
)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetClipWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
lib.infiniopGetClipWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, x.device)
......@@ -164,8 +160,8 @@ def test(
workspace_size.value,
y_tensor.data,
x_tensor.data,
c_float(min_val),
c_float(max_val),
min_tensor.data,
max_tensor.data,
None,
)
)
......@@ -209,6 +205,8 @@ if __name__ == "__main__":
POINTER(infiniopClipDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetClipWorkspaceSize.restype = c_int32
......@@ -224,8 +222,8 @@ if __name__ == "__main__":
c_uint64,
c_void_p,
c_void_p,
c_float,
c_float,
c_void_p,
c_void_p,
c_void_p,
]
......
......@@ -11,14 +11,22 @@ def check_error(status):
raise Exception("Error code " + str(status))
def to_tensor(tensor, lib, force_unsigned=False):
def to_tensor(tensor, lib, force_unsigned=False, force_shape=None, force_strides=None):
"""
Convert a PyTorch tensor to a library Tensor(descriptor, data).
"""
import torch
ndim = tensor.ndimension()
if force_shape is not None:
ndim = len(force_shape)
shape = (ctypes.c_size_t * ndim)(*force_shape)
else:
shape = (ctypes.c_size_t * ndim)(*tensor.shape)
if force_strides is not None:
ndim = len(force_strides)
strides = (ctypes.c_int64 * ndim)(*force_strides)
else:
strides = (ctypes.c_int64 * ndim)(*(tensor.stride()))
# fmt: off
dt = (
......
......@@ -31,6 +31,10 @@ target("infiniop-cuda")
add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC")
add_cxxflags("-fPIC")
add_cuflags("--expt-relaxed-constexpr")
if CUDNN_ROOT ~= nil then
add_linkdirs(CUDNN_ROOT .. "/lib")
end
end
set_languages("cxx17")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment