Commit 9bd8df78 authored by Zimin Li
Browse files

issue/74 fix inplace issue and various naming and import issues

parent 5619c372
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
from libinfiniop import ( from libinfiniop import (
InfiniDtype,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
open_lib, open_lib,
to_tensor, to_tensor,
get_test_devices, get_test_devices,
check_error, check_error,
rearrange_if_needed,
create_workspace, create_workspace,
test_operator, test_operator,
get_args, get_args,
...@@ -57,7 +57,7 @@ class RandomSampleDescriptor(Structure): ...@@ -57,7 +57,7 @@ class RandomSampleDescriptor(Structure):
infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample(data, random_val, topp, topk, voc, temperature):
if topp > 0 and topk > 1: if topp > 0 and topk > 1:
indices = torch.zeros([topk], dtype=torch.int64) indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach() dataNp = data.clone().detach()
...@@ -115,23 +115,25 @@ def test( ...@@ -115,23 +115,25 @@ def test(
topp, topp,
topk, topk,
temperature, temperature,
x_dtype=torch.float16, dtype=torch.float16,
): ):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}") print(
f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}"
)
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(dtype).to(torch_device)
ans = random_sample( ans = random_sample(
data, random_val, topp, topk, voc, temperature, torch_device data, random_val, topp, topk, voc, temperature
) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程 ) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]] x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64 indices_tensor.descriptor.contents.dt = InfiniDtype.U64 # treat int64 as uint64
descriptor = infiniopRandomSampleDescriptor_t() descriptor = infiniopRandomSampleDescriptor_t()
check_error( check_error(
...@@ -191,7 +193,7 @@ def test( ...@@ -191,7 +193,7 @@ def test(
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: random_sample( profile_operation("PyTorch", lambda: random_sample(
data, random_val, topp, topk, voc, temperature, torch_device data, random_val, topp, topk, voc, temperature
), torch_device, NUM_PRERUN, NUM_ITERATIONS) ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
......
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_void_p
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -9,7 +9,7 @@ from libinfiniop import ( ...@@ -9,7 +9,7 @@ from libinfiniop import (
get_test_devices, get_test_devices,
check_error, check_error,
rearrange_if_needed, rearrange_if_needed,
create_workspace, rearrange_tensor,
test_operator, test_operator,
get_args, get_args,
debug, debug,
...@@ -62,14 +62,14 @@ def test( ...@@ -62,14 +62,14 @@ def test(
x_stride, x_stride,
y_shape, y_shape,
y_stride, y_stride,
x_dtype=torch.float16, dtype=torch.float16,
): ):
print( print(
f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}" f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} dtype:{dtype}"
) )
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
......
...@@ -2,7 +2,7 @@ from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float ...@@ -2,7 +2,7 @@ from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes import ctypes
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -69,12 +69,12 @@ def test( ...@@ -69,12 +69,12 @@ def test(
w_shape, w_shape,
y_stride, y_stride,
x_stride, x_stride,
dtype=torch.float16,
w_dtype=torch.float16, w_dtype=torch.float16,
dtype=torch.float16,
): ):
print( print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}" f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}" f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{w_dtype} dtype:{dtype}"
) )
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import ( from libinfiniop import (
InfiniDtype,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
open_lib, open_lib,
...@@ -131,7 +132,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -131,7 +132,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
check_error( check_error(
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
handle, handle,
byref(descriptor), ctypes.byref(descriptor),
t_tensor.descriptor, t_tensor.descriptor,
pos_tensor.descriptor, pos_tensor.descriptor,
sin_table_tensor.descriptor, sin_table_tensor.descriptor,
...@@ -231,4 +232,5 @@ if __name__ == "__main__": ...@@ -231,4 +232,5 @@ if __name__ == "__main__":
# Execute tests # Execute tests
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_void_p
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -9,7 +9,6 @@ from libinfiniop import ( ...@@ -9,7 +9,6 @@ from libinfiniop import (
get_test_devices, get_test_devices,
check_error, check_error,
rearrange_if_needed, rearrange_if_needed,
create_workspace,
test_operator, test_operator,
get_args, get_args,
debug, debug,
...@@ -23,14 +22,15 @@ from enum import Enum, auto ...@@ -23,14 +22,15 @@ from enum import Enum, auto
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES_ = [ _TEST_CASES_ = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None), ((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)), ((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None), # ((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), # ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None), ((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)), ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None), # ((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), # ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
] ]
# Inplace options applied for each test case in _TEST_CASES_ # Inplace options applied for each test case in _TEST_CASES_
...@@ -91,16 +91,13 @@ def test( ...@@ -91,16 +91,13 @@ def test(
sync=None, sync=None,
): ):
print( print(
f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}" f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} "
f"dtype:{dtype} inplace:{inplace}"
) )
a = torch.rand(shape, dtype=dtype).to(torch_device) a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device)
c = ( c = torch.rand(shape, dtype=dtype).to(torch_device)
torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
ans = swiglu(a, b) ans = swiglu(a, b)
...@@ -108,6 +105,11 @@ def test( ...@@ -108,6 +105,11 @@ def test(
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride]) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
] ]
c = (
c
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]] a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = ( c_tensor = (
to_tensor(c, lib) to_tensor(c, lib)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment