Commit aeeae7ea authored by Catheriany's avatar Catheriany
Browse files

issue/228: swiglu测例0步长添加

parent 3329034a
from .infiniop_test import InfiniopTestCase, InfiniopTestWriter, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides from .infiniop_test import InfiniopTestCase, InfiniopTestWriter, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor
...@@ -37,6 +37,18 @@ def contiguous_gguf_strides(shape: tuple[int, ...]) -> list[int]: ...@@ -37,6 +37,18 @@ def contiguous_gguf_strides(shape: tuple[int, ...]) -> list[int]:
acc *= size acc *= size
return strides[::-1] return strides[::-1]
def process_zero_stride_tensor(a, b, stride_a=None, stride_b=None):
def normalize_stride(tensor, stride):
if stride:
slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride)
return tensor[slices]
else:
return tensor
a_unique = normalize_stride(a, stride_a)
b_unique = normalize_stride(b, stride_b)
return a_unique, b_unique
class InfiniopTestCase: class InfiniopTestCase:
op_name: str op_name: str
......
...@@ -4,7 +4,7 @@ import gguf ...@@ -4,7 +4,7 @@ import gguf
from typing import List from typing import List
from numpy.lib.stride_tricks import as_strided from numpy.lib.stride_tricks import as_strided
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor
def add( def add(
...@@ -13,17 +13,6 @@ def add( ...@@ -13,17 +13,6 @@ def add(
): ):
return a + b return a + b
def process_tensor(a, b, stride_a=None, stride_b=None):
def normalize_stride(tensor, stride):
if stride:
slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride)
return tensor[slices]
else:
return tensor
a_unique = normalize_stride(a, stride_a)
b_unique = normalize_stride(b, stride_b)
return a_unique, b_unique
class AddTestCase(InfiniopTestCase): class AddTestCase(InfiniopTestCase):
def __init__( def __init__(
...@@ -111,9 +100,7 @@ if __name__ == "__main__": ...@@ -111,9 +100,7 @@ if __name__ == "__main__":
a = np.random.rand(*shape).astype(dtype) a = np.random.rand(*shape).astype(dtype)
b = np.random.rand(*shape).astype(dtype) b = np.random.rand(*shape).astype(dtype)
c = np.empty(tuple(0 for _ in shape), dtype=dtype) c = np.empty(tuple(0 for _ in shape), dtype=dtype)
a, b = process_tensor(a, b, stride_a, stride_b) a, b = process_zero_stride_tensor(a, b, stride_a, stride_b)
if stride_c is None:
stride_c = contiguous_gguf_strides(shape)
test_case = AddTestCase( test_case = AddTestCase(
a=a, a=a,
shape_a=shape, shape_a=shape,
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import gguf import gguf
from typing import List from typing import List
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor
def swiglu( def swiglu(
...@@ -92,6 +92,8 @@ if __name__ == "__main__": ...@@ -92,6 +92,8 @@ if __name__ == "__main__":
((2, 3, 400), (1200, 400, 1), (1200, 400, 1), (1, 2, 6)), ((2, 3, 400), (1200, 400, 1), (1200, 400, 1), (1, 2, 6)),
((4, 4, 5632), None, None, None), ((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
((13, 4), (0, 1), None, None),
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
] ]
_TENSOR_DTYPES_ = [np.float32, np.float16] _TENSOR_DTYPES_ = [np.float32, np.float16]
...@@ -100,6 +102,7 @@ if __name__ == "__main__": ...@@ -100,6 +102,7 @@ if __name__ == "__main__":
a = np.random.rand(*shape).astype(dtype) a = np.random.rand(*shape).astype(dtype)
b = np.random.rand(*shape).astype(dtype) b = np.random.rand(*shape).astype(dtype)
c = np.empty(tuple(0 for _ in shape), dtype=dtype) c = np.empty(tuple(0 for _ in shape), dtype=dtype)
a, b = process_zero_stride_tensor(a, b, stride_a, stride_b)
test_case = SwiGLUTestCase( test_case = SwiGLUTestCase(
a=a, a=a,
shape_a=list(shape), shape_a=list(shape),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment