Commit 9bd8df78 authored by Zimin Li

issue/74 fix inplace issue and various naming and import issues

parent 5619c372
 import torch
 import ctypes
-from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
+from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
 from libinfiniop import (
+    InfiniDtype,
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
     open_lib,
     to_tensor,
     get_test_devices,
     check_error,
     rearrange_if_needed,
     create_workspace,
     test_operator,
     get_args,
@@ -57,7 +57,7 @@ class RandomSampleDescriptor(Structure):
 infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)

-def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
+def random_sample(data, random_val, topp, topk, voc, temperature):
     if topp > 0 and topk > 1:
         indices = torch.zeros([topk], dtype=torch.int64)
         dataNp = data.clone().detach()
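
For context, this PyTorch reference implements combined top-k/top-p sampling. A minimal sketch of a sampler in that style follows; it is a hypothetical illustration written for this page, not the repository's random_sample, and the renormalization and cutoff details are assumptions:

import torch

def topk_topp_sample(logits, random_val, topp, topk, temperature):
    # Sort descending and keep the top-k logits under a temperature softmax.
    vals, idx = torch.sort(logits.float(), descending=True)
    probs = torch.softmax(vals[:topk] / temperature, dim=0)
    cdf = torch.cumsum(probs, dim=0)
    # Top-p cutoff: the smallest prefix whose cumulative mass reaches topp.
    k = min(int(torch.searchsorted(cdf, min(topp, 1.0))) + 1, topk)
    # Inverse-CDF draw using the externally supplied random value.
    pick = int(torch.searchsorted(cdf[:k], random_val * cdf[k - 1]))
    return idx[min(pick, k - 1)]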
@@ -115,23 +115,25 @@ def test(
     topp,
     topk,
     temperature,
-    x_dtype=torch.float16,
+    dtype=torch.float16,
 ):
-    print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
+    print(
+        f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}"
+    )
     data = torch.arange(voc).float() * 0.0001
     _perm = torch.randperm(voc)
-    data = data[_perm].to(x_dtype).to(torch_device)
+    data = data[_perm].to(dtype).to(torch_device)
     ans = random_sample(
-        data, random_val, topp, topk, voc, temperature, torch_device
+        data, random_val, topp, topk, voc, temperature
     )  # this function can be very slow on device; moving the data to the CPU first with data.to("cpu") speeds it up
     indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
     x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
-    indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64
+    indices_tensor.descriptor.contents.dt = InfiniDtype.U64  # treat int64 as uint64
     descriptor = infiniopRandomSampleDescriptor_t()
     check_error(
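
A note on the InfiniDtype.U64 retag above: the descriptor is a ctypes structure reached through a POINTER, so its fields can be patched after creation; the test builds the indices tensor as int64 (which PyTorch supports everywhere) and simply relabels the descriptor, since nonnegative int64 indices are bit-identical to uint64. A self-contained sketch of the mechanism, using a hypothetical one-field descriptor rather than the library's real layout:

import ctypes

class Desc(ctypes.Structure):
    # Hypothetical layout for illustration; not the library's descriptor.
    _fields_ = [("dt", ctypes.c_int32)]

d = ctypes.pointer(Desc(dt=3))
d.contents.dt = 7  # retag the dtype field in place, as the test does
assert d.contents.dt == 7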
@@ -191,7 +193,7 @@ def test(
     if PROFILE:
         # fmt: off
         profile_operation("PyTorch", lambda: random_sample(
-            data, random_val, topp, topk, voc, temperature, torch_device
+            data, random_val, topp, topk, voc, temperature
         ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
...
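The profile_operation calls above time both the PyTorch reference and the library kernel. Its internals are not shown in this diff, but the call pattern (NUM_PRERUN warm-up runs, then NUM_ITERATIONS timed runs on a given device) suggests a helper of roughly this shape; the version below is a hypothetical standalone sketch:

import time
import torch

def profile(name, fn, device, num_prerun, num_iterations):
    for _ in range(num_prerun):  # warm-up runs, excluded from timing
        fn()
    if device == "cuda":
        torch.cuda.synchronize()  # drain pending kernels before timing
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"{name}: {(time.perf_counter() - start) / num_iterations * 1e3:.3f} ms/iter")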
 import torch
 import ctypes
-from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
+from ctypes import POINTER, Structure, c_int32, c_void_p
 from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
@@ -9,7 +9,7 @@ from libinfiniop import (
     get_test_devices,
     check_error,
     rearrange_if_needed,
-    create_workspace,
+    rearrange_tensor,
     test_operator,
     get_args,
     debug,
@@ -62,14 +62,14 @@ def test(
     x_stride,
     y_shape,
     y_stride,
-    x_dtype=torch.float16,
+    dtype=torch.float16,
 ):
     print(
-        f"Testing Rearrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
+        f"Testing Rearrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} dtype:{dtype}"
     )
-    x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
-    y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
+    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
+    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
     x, y = [
         rearrange_if_needed(tensor, stride)
...
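In the Rearrange test above, rearrange_if_needed re-lays a tensor out with an explicitly requested stride so the non-default stride cases get exercised. A plausible equivalent (assumed behavior, not the library's code):

import torch

def rearrange_if_needed_sketch(tensor, stride):
    if stride is None:
        return tensor  # default layout: nothing to do
    # Allocate storage with the requested stride and copy the data in.
    out = torch.empty_strided(
        tensor.shape, stride, dtype=tensor.dtype, device=tensor.device
    )
    out.copy_(tensor)
    return out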
@@ -2,7 +2,7 @@
 import torch
 import ctypes
-from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
+from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
 from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
@@ -69,12 +69,12 @@ def test(
     w_shape,
     y_stride,
     x_stride,
-    dtype=torch.float16,
     w_dtype=torch.float16,
+    dtype=torch.float16,
 ):
     print(
         f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
-        f" dtype:{dtype} w_dtype:{w_dtype}"
+        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{w_dtype} dtype:{dtype}"
     )
     y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
...
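For reference on the RMS_Norm test above: the operator is conventionally defined as below. This is the textbook RMSNorm formula; the epsilon value and the float32 accumulation are assumptions, not necessarily the repository's reference:

import torch

def rms_norm_ref(x, w, eps=1e-5):
    # Normalize by the root mean square over the last dimension,
    # then scale by the elementwise weight; accumulate in float32.
    x32 = x.float()
    rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 / rms * w.float()).to(x.dtype)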
@@ -2,6 +2,7 @@ import torch
 import ctypes
 from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
 from libinfiniop import (
+    InfiniDtype,
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
     open_lib,
@@ -131,7 +132,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     check_error(
         lib.infiniopCreateRoPEDescriptor(
             handle,
-            byref(descriptor),
+            ctypes.byref(descriptor),
             t_tensor.descriptor,
             pos_tensor.descriptor,
             sin_table_tensor.descriptor,
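
The byref change is one of the commit's "import issues": this file imports specific names from ctypes but never byref itself, so the bare call raised NameError at runtime, and qualifying it through the already-imported module fixes that. Minimal demonstration:

import ctypes

n = ctypes.c_int32(0)
p = ctypes.byref(n)  # lightweight pass-by-reference handle for foreign calls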
@@ -231,4 +232,5 @@ if __name__ == "__main__":
     # Execute tests
     for device in get_test_devices(args):
         test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
+    print("\033[92mTest passed!\033[0m")
 import torch
 import ctypes
-from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
+from ctypes import POINTER, Structure, c_int32, c_void_p
 from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
@@ -9,7 +9,6 @@ from libinfiniop import (
     get_test_devices,
     check_error,
     rearrange_if_needed,
-    create_workspace,
     test_operator,
     get_args,
     debug,
@@ -23,14 +22,15 @@ from enum import Enum, auto
 # ==============================================================================
 # These are not meant to be imported from other modules
 _TEST_CASES_ = [
+    # shape, a_stride, b_stride, c_stride
     ((13, 4), None, None, None),
     ((13, 4), (10, 1), (10, 1), (10, 1)),
-    ((13, 4, 4), None, None, None),
-    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
+    # ((13, 4, 4), None, None, None),
+    # ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
     ((16, 5632), None, None, None),
     ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
-    ((4, 4, 5632), None, None, None),
-    ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
+    # ((4, 4, 5632), None, None, None),
+    # ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
 ]

 # Inplace options applied for each test case in _TEST_CASES_
@@ -91,16 +91,13 @@ def test(
     sync=None,
 ):
     print(
-        f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
+        f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} "
+        f"dtype:{dtype} inplace:{inplace}"
     )
     a = torch.rand(shape, dtype=dtype).to(torch_device)
     b = torch.rand(shape, dtype=dtype).to(torch_device)
-    c = (
-        torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
-        if inplace == Inplace.OUT_OF_PLACE
-        else (a if inplace == Inplace.INPLACE_A else b)
-    )
+    c = torch.rand(shape, dtype=dtype).to(torch_device)

     ans = swiglu(a, b)
@@ -108,6 +105,11 @@ def test(
         rearrange_if_needed(tensor, stride)
         for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
     ]
+    c = (
+        c
+        if inplace == Inplace.OUT_OF_PLACE
+        else (a if inplace == Inplace.INPLACE_A else b)
+    )
     a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
     c_tensor = (
         to_tensor(c, lib)
...
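
The last hunk is the in-place fix named in the commit title. Previously c was aliased to a or b before rearrange_if_needed ran, so rearranging an operand replaced it with a fresh tensor and silently broke the alias; the removed code also referenced the undefined names c_shape and tensor_dtype. Aliasing after the rearrange keeps c bound to the tensor the kernel actually reads. For reference, SwiGLU is conventionally a * silu(b), as sketched below; whether this repository gates on a or b is an assumption:

import torch

def swiglu_ref(a, b):
    # Gated activation: silu(b) = b * sigmoid(b) gates a elementwise.
    return a * torch.nn.functional.silu(b)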