Unverified Commit 3eb14921 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #210 from InfiniTensor/issue-209-elementwise-kunlun

Issue/209 elementwise kunlun
parents 1a4cfb99 cda0ccba
......@@ -51,6 +51,7 @@ def test(
torch_device,
x_shape,
tensor_dtype=torch.float16,
sync=None
):
print(
f"Testing GlobalAvgPool on {torch_device} with input tensor_shape: {x_shape} dtype: {tensor_dtype}"
......@@ -70,8 +71,11 @@ def test(
x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopGlobalAvgPoolDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopGlobalAvgPoolDescriptor_t()
check_error(
lib.infiniopCreateGlobalAvgPoolDescriptor(
handle,
......
......@@ -423,6 +423,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
infiniDeviceEnum_str_map[device],
*test_case,
tensor_dtype,
get_sync_func(device)
)
finally:
destroy_handle(lib, handle)
......@@ -471,3 +472,14 @@ def get_test_devices(args):
devices_to_test = [InfiniDeviceEnum.CPU]
return devices_to_test
def get_sync_func(device):
    """Return the torch synchronize callable for ``device``, or ``None`` for CPU.

    The test driver calls this with the same ``device`` value it uses as a
    key into ``infiniDeviceEnum_str_map``, so the CPU check has to be made on
    the mapped device name; comparing the key itself against ``"cpu"`` never
    matches for that caller. The literal string ``"cpu"`` is still accepted
    so that any caller passing the name directly keeps getting ``None``.

    Args:
        device: A key of ``infiniDeviceEnum_str_map`` (presumably an
            ``InfiniDeviceEnum`` member -- confirm against the map's
            definition), or the literal string ``"cpu"``.

    Returns:
        ``None`` when the device resolves to CPU; otherwise the
        ``synchronize`` attribute of the torch submodule named after the
        device, i.e. ``getattr(torch, name).synchronize``.

    Raises:
        AttributeError: If torch has no submodule with the mapped name, or
            that submodule has no ``synchronize`` attribute.
    """
    # Backward-compatible fast path for callers that pass the device name.
    if device == "cpu":
        return None

    device_name = infiniDeviceEnum_str_map[device]
    # CPU needs no synchronization; returning None lets tests skip the call
    # through their ``if sync is not None`` guard.
    if device_name == "cpu":
        return None

    # Imported lazily, as in the original, so torch is only needed when a
    # non-CPU device is actually requested.
    import torch

    return getattr(torch, device_name).synchronize
......@@ -83,6 +83,7 @@ def test(
padding,
strides,
tensor_dtype=torch.float16,
sync=None
):
print(
f"Testing MaxPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
......@@ -104,8 +105,11 @@ def test(
x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopMaxPoolDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopMaxPoolDescriptor_t()
check_error(
lib.infiniopCreateMaxPoolDescriptor(
handle,
......
......@@ -65,6 +65,7 @@ def test(
y_stride=None,
w12_stride=None,
w3_stride=None,
sync=None
):
print(
f"Testing MLP on {torch_device} with num_tokens:{num_tokens} hidden_size:{hidden_size} intermediate_size:{intermediate_size}"
......@@ -97,6 +98,10 @@ def test(
x_tensor = to_tensor(x, lib)
w12_tensor = to_tensor(w12, lib)
w3_tensor = to_tensor(w3, lib)
if sync is not None:
sync()
descriptor = infiniopMLPDescriptor_t()
check_error(
lib.infiniopCreateMLPDescriptor(
......
......@@ -103,6 +103,7 @@ def test(
topk,
temperature,
dtype=torch.float16,
sync=None
):
print(
f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}"
......@@ -122,6 +123,9 @@ def test(
indices_tensor.descriptor.contents.dt = InfiniDtype.U64 # treat int64 as uint64
if sync is not None:
sync()
descriptor = infiniopRandomSampleDescriptor_t()
check_error(
lib.infiniopCreateRandomSampleDescriptor(
......
......@@ -131,6 +131,7 @@ def test(
x_stride,
y_stride,
dtype=torch.float16,
sync=None
):
print(
f"Testing Rerrange on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype}"
......@@ -145,6 +146,9 @@ def test(
]
x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
if sync is not None:
sync()
descriptor = infiniopRearrangeDescriptor_t()
check_error(
......
......@@ -55,6 +55,7 @@ def test(
tensor_shape,
tensor_dtype=torch.float16,
inplace=Inplace.OUT_OF_PLACE,
sync=None
):
print(
f"Testing Relu on {torch_device} with tensor_shape:{tensor_shape} dtype:{tensor_dtype} inplace: {inplace.name}"
......@@ -78,8 +79,11 @@ def test(
x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor
descriptor = infiniopReluDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopReluDescriptor_t()
check_error(
lib.infiniopCreateReluDescriptor(
handle,
......
......@@ -72,6 +72,7 @@ def test(
x_stride,
w_dtype=torch.float16,
dtype=torch.float16,
sync=None
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
......@@ -89,9 +90,11 @@ def test(
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
if sync is not None:
sync()
descriptor = infiniopRMSNormDescriptor_t()
check_error(
......
......@@ -117,6 +117,7 @@ def test(
y_strides=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32,
sync=None
):
if inplace == Inplace.INPLACE_X:
y_strides = x_strides
......@@ -147,8 +148,8 @@ def test(
else:
y_tensor = to_tensor(y, lib)
if torch_device == "npu":
synchronize_device(torch_device)
if sync is not None:
sync()
check_error(
lib.infiniopCreateRoPEDescriptor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment