Unverified Commit c98e68be authored by goldenfox2025's avatar goldenfox2025 Committed by GitHub
Browse files

Merge branch 'main' into issue180

parents d7c12d52 125afeb5
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../elementwise/kunlun/elementwise_kunlun_kernel.h"
/// @brief Define swiglu op for local mem
typedef struct SwiGLUOp {
private:
    // Logistic sigmoid evaluated in device code: 1 / (1 + e^-x).
    // NOTE(review): uses `exp`; only the float instantiation is emitted below,
    // so the float literals are appropriate — confirm if more types are added.
    template <typename T>
    inline __device__ T sigmoid(T x) const {
        return 1.0f / (1.0f + exp(-x));
    }

public:
    // Required by the elementwise kernel framework: number of input operands.
    static constexpr size_t num_inputs = 2;

    // SwiGLU: out = up * gate * sigmoid(gate), where
    // inputs[0] is the "up" projection and inputs[1] is the gate value.
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        const T up_val = inputs[0];
        const T gate_val = inputs[1];
        return up_val * (gate_val * sigmoid(gate_val));
    }
} SwiGLUOp;
// Generate the SwiGLU kernel launch interface from the shared elementwise framework
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)
// Explicit template instantiation for the float element type
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)
#endif // __SWIGLU_KUNLUN_H__
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh" #include "cuda/swiglu_cuda.cuh"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor( __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( ...@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda); CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangCreateSwiGLUDescriptor((BangHandle_t)handle, return bangCreateSwiGLUDescriptor((BangHandle_t)handle,
...@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des ...@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda) GET(INFINI_DEVICE_NVIDIA, cuda)
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size); return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
...@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopSwiGLU( ...@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopSwiGLU(
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda); CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangSwiGLU((SwiGLUBangDescriptor_t)desc, c, a, b, stream); return bangSwiGLU((SwiGLUBangDescriptor_t)desc, c, a, b, stream);
...@@ -168,6 +180,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { ...@@ -168,6 +180,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda); DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangDestroySwiGLUDescriptor((SwiGLUBangDescriptor_t)desc); return bangDestroySwiGLUDescriptor((SwiGLUBangDescriptor_t)desc);
......
#ifndef __INFINIOP_REDUCE_KUNLUN_H__ #ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__ #define __INFINIOP_REDUCE_KUNLUN_H__
#include "../../devices/kunlun/kunlun_common.h" #include "../../devices/kunlun/kunlun_kernel_common.h"
namespace op::common_kunlun::reduce_op { namespace op::common_kunlun::reduce_op {
using namespace device::kunlun::kernel;
// Use 16 floats instruction to calculate reduce // Use 16 floats instruction to calculate reduce
// data_ptr is the pointer of LM // data_ptr is the pointer of LM
static inline __device__ float sumSquaredF32(float *data_ptr, int count) { static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "bang/infinirt_bang.h" #include "bang/infinirt_bang.h"
#include "cpu/infinirt_cpu.h" #include "cpu/infinirt_cpu.h"
#include "cuda/infinirt_cuda.cuh" #include "cuda/infinirt_cuda.cuh"
#include "kunlun/infinirt_kunlun.h"
#include "maca/infinirt_maca.h" #include "maca/infinirt_maca.h"
#include "musa/infinirt_musa.h" #include "musa/infinirt_musa.h"
...@@ -66,8 +67,11 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_ ...@@ -66,8 +67,11 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
case INFINI_DEVICE_MOORE: \ case INFINI_DEVICE_MOORE: \
_status = infinirt::musa::API PARAMS; \ _status = infinirt::musa::API PARAMS; \
break; \ break; \
case INFINI_DEVICE_KUNLUN: \
_status = infinirt::kunlun::API PARAMS; \
break; \
default: \ default: \
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \ _status = INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
} \ } \
{ ACTION; } \ { ACTION; } \
return _status; \ return _status; \
......
import numpy as np
import gguf
from typing import List
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides
def swiglu(
    a: np.ndarray,
    b: np.ndarray,
):
    """Reference SwiGLU: a * b * sigmoid(b), computed with NumPy broadcasting.

    Here ``a`` is the "up" projection and ``b`` is the gate.
    """
    gate_activation = 1.0 / (1.0 + np.exp(-b))
    return a * b * gate_activation
class SwiGLUTestCase(InfiniopTestCase):
    """GGUF writer for one swiglu test case.

    Stores inputs ``a`` (up) and ``b`` (gate), the output placeholder ``c``,
    and optional strides for each (``None`` means contiguous); the float64
    reference answer is recomputed at write time.
    """

    def __init__(
        self,
        a: np.ndarray,
        stride_a: List[int] | None,
        b: np.ndarray,
        stride_b: List[int] | None,
        c: np.ndarray,
        stride_c: List[int] | None,
    ):
        super().__init__("swiglu")
        self.a = a
        self.stride_a = stride_a
        self.b = b
        self.stride_b = stride_b
        self.c = c
        self.stride_c = stride_c

    def write_test(self, test_writer: "InfiniopTestWriter"):
        # Common metadata first, then per-tensor strides, tensors, and answer
        # (same emission order as the rest of the test suite).
        super().write_test(test_writer)
        stride_entries = (
            ("a.strides", self.stride_a),
            ("b.strides", self.stride_b),
            ("c.strides", self.stride_c),
        )
        for key, strides in stride_entries:
            if strides is not None:
                test_writer.add_array(test_writer.gguf_key(key), strides)
        for key, tensor in (("a", self.a), ("b", self.b), ("c", self.c)):
            test_writer.add_tensor(
                test_writer.gguf_key(key),
                tensor,
                raw_dtype=np_dtype_to_ggml(tensor.dtype),
            )
        # Reference answer computed in float64 so comparisons are not limited
        # by the input dtype's precision.
        reference = swiglu(
            self.a.astype(np.float64),
            self.b.astype(np.float64),
        )
        test_writer.add_tensor(
            test_writer.gguf_key("ans"),
            reference,
            raw_dtype=gguf.GGMLQuantizationType.F64,
        )
if __name__ == "__main__":
    test_writer = InfiniopTestWriter("swiglu.gguf")

    # One entry per test case: (shape, dtype, stride_a, stride_b, stride_c).
    # A stride of None means contiguous; tuples are raw strides that are
    # converted through gguf_strides below. Order matters: cases are written
    # in this sequence and each case draws a, b, c in that order from the
    # (unseeded) global RNG.
    case_specs = [
        ((64, 128), np.float32, None, None, None),
        ((64, 121), np.float32, None, None, None),
        ((15, 512), np.float32, None, None, None),
        ((13, 4), np.float32, None, None, None),
        ((13, 4), np.float16, None, None, None),
        ((13, 4), np.float32, (10, 1), (10, 1), (10, 1)),
        ((13, 4), np.float16, (10, 1), (10, 1), (10, 1)),
        ((13, 4, 4), np.float32, None, None, None),
        ((13, 4, 4), np.float16, None, None, None),
        ((13, 4, 4), np.float32, (20, 4, 1), (20, 4, 1), (20, 4, 1)),
        ((13, 4, 4), np.float16, (20, 4, 1), (20, 4, 1), (20, 4, 1)),
        ((16, 5632), np.float32, None, None, None),
        ((16, 5632), np.float16, None, None, None),
        ((16, 5632), np.float32, (13312, 1), (13312, 1), (13312, 1)),
        ((16, 5632), np.float16, (13312, 1), (13312, 1), (13312, 1)),
        # Non-contiguous output layouts (column-major-like c).
        ((16, 5632), np.float32, (5632, 1), (5632, 1), (1, 16)),
        ((16, 5632), np.float16, (5632, 1), (5632, 1), (1, 16)),
        ((2, 3, 400), np.float32, (1200, 400, 1), (1200, 400, 1), (1, 2, 6)),
        ((2, 3, 400), np.float16, (1200, 400, 1), (1200, 400, 1), (1, 2, 6)),
        ((4, 4, 5632), np.float32, None, None, None),
        ((4, 4, 5632), np.float16, None, None, None),
        ((4, 4, 5632), np.float32, (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
        ((4, 4, 5632), np.float16, (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
    ]

    def _as_gguf_strides(raw):
        # Translate a raw stride tuple into gguf layout; pass None through.
        return gguf_strides(*raw) if raw is not None else None

    test_cases = [
        SwiGLUTestCase(
            np.random.rand(*shape).astype(dtype),
            _as_gguf_strides(sa),
            np.random.rand(*shape).astype(dtype),
            _as_gguf_strides(sb),
            np.random.rand(*shape).astype(dtype),
            _as_gguf_strides(sc),
        )
        for shape, dtype, sa, sb, sc in case_specs
    ]

    test_writer.add_tests(test_cases)
    test_writer.save()
...@@ -101,6 +101,7 @@ def test( ...@@ -101,6 +101,7 @@ def test(
v_stride=None, v_stride=None,
k_cache_stride=None, k_cache_stride=None,
v_cache_stride=None, v_cache_stride=None,
sync=None
): ):
print( print(
f"Testing Attention on {torch_device} with n_q_head:{n_q_head} n_kv_head:{n_kv_head} seq_len:{seq_len} head_dim:{head_dim} pos:{pos} " f"Testing Attention on {torch_device} with n_q_head:{n_q_head} n_kv_head:{n_kv_head} seq_len:{seq_len} head_dim:{head_dim} pos:{pos} "
...@@ -139,6 +140,9 @@ def test( ...@@ -139,6 +140,9 @@ def test(
v_tensor = to_tensor(v, lib) v_tensor = to_tensor(v, lib)
k_cache_tensor = to_tensor(k_cache, lib) k_cache_tensor = to_tensor(k_cache, lib)
v_cache_tensor = to_tensor(v_cache, lib) v_cache_tensor = to_tensor(v_cache, lib)
if sync is not None:
sync()
descriptor = infiniopAttentionDescriptor_t() descriptor = infiniopAttentionDescriptor_t()
check_error( check_error(
......
...@@ -88,6 +88,7 @@ def test( ...@@ -88,6 +88,7 @@ def test(
padding, padding,
strides, strides,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing AvgPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}" f"Testing AvgPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
...@@ -109,6 +110,10 @@ def test( ...@@ -109,6 +110,10 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
if sync is not None:
sync()
descriptor = infiniopAvgPoolDescriptor_t() descriptor = infiniopAvgPoolDescriptor_t()
check_error( check_error(
......
...@@ -87,6 +87,7 @@ def test( ...@@ -87,6 +87,7 @@ def test(
y_stride=None, y_stride=None,
inplace=Inplace.OUT_OF_PLACE, inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16, dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}" f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
...@@ -107,6 +108,9 @@ def test( ...@@ -107,6 +108,9 @@ def test(
y = torch.zeros(shape, dtype=dtype).to(torch_device) y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = rearrange_if_needed(y, y_stride) y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
if sync is not None:
sync()
descriptor = infiniopCausalSoftmaxDescriptor_t() descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error( check_error(
......
...@@ -95,6 +95,7 @@ def test( ...@@ -95,6 +95,7 @@ def test(
dilations, dilations,
tensor_stride=None, tensor_stride=None,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
sync=None
): ):
assert len(pads) == len(strides) == len(dilations) assert len(pads) == len(strides) == len(dilations)
print( print(
...@@ -118,8 +119,11 @@ def test( ...@@ -118,8 +119,11 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
w_tensor = to_tensor(w, lib) w_tensor = to_tensor(w, lib)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
descriptor = infiniopConvDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopConvDescriptor_t()
check_error( check_error(
lib.infiniopCreateConvDescriptor( lib.infiniopCreateConvDescriptor(
handle, handle,
......
...@@ -52,6 +52,7 @@ def test( ...@@ -52,6 +52,7 @@ def test(
y_stride=None, y_stride=None,
x_stride=None, x_stride=None,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}" f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}"
...@@ -76,8 +77,11 @@ def test( ...@@ -76,8 +77,11 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
descriptor = infiniopExpandDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopExpandDescriptor_t()
check_error( check_error(
lib.infiniopCreateExpandDescriptor( lib.infiniopCreateExpandDescriptor(
handle, handle,
......
...@@ -83,6 +83,7 @@ def test( ...@@ -83,6 +83,7 @@ def test(
b_stride=None, b_stride=None,
c_stride=None, c_stride=None,
dtype=torch.float16, dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta}," f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
...@@ -104,6 +105,9 @@ def test( ...@@ -104,6 +105,9 @@ def test(
] ]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]] a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
if sync is not None:
sync()
descriptor = infiniopGemmDescriptor_t() descriptor = infiniopGemmDescriptor_t()
check_error( check_error(
lib.infiniopCreateGemmDescriptor( lib.infiniopCreateGemmDescriptor(
......
...@@ -51,6 +51,7 @@ def test( ...@@ -51,6 +51,7 @@ def test(
torch_device, torch_device,
x_shape, x_shape,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing GlobalAvgPool on {torch_device} with input tensor_shape: {x_shape} dtype: {tensor_dtype}" f"Testing GlobalAvgPool on {torch_device} with input tensor_shape: {x_shape} dtype: {tensor_dtype}"
...@@ -70,8 +71,11 @@ def test( ...@@ -70,8 +71,11 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
descriptor = infiniopGlobalAvgPoolDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopGlobalAvgPoolDescriptor_t()
check_error( check_error(
lib.infiniopCreateGlobalAvgPoolDescriptor( lib.infiniopCreateGlobalAvgPoolDescriptor(
handle, handle,
......
...@@ -423,6 +423,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes): ...@@ -423,6 +423,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
infiniDeviceEnum_str_map[device], infiniDeviceEnum_str_map[device],
*test_case, *test_case,
tensor_dtype, tensor_dtype,
get_sync_func(device)
) )
finally: finally:
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -471,3 +472,14 @@ def get_test_devices(args): ...@@ -471,3 +472,14 @@ def get_test_devices(args):
devices_to_test = [InfiniDeviceEnum.CPU] devices_to_test = [InfiniDeviceEnum.CPU]
return devices_to_test return devices_to_test
def get_sync_func(device):
import torch
if device == "cpu":
sync = None
else:
sync = getattr(torch, infiniDeviceEnum_str_map[device]).synchronize
return sync
...@@ -83,6 +83,7 @@ def test( ...@@ -83,6 +83,7 @@ def test(
padding, padding,
strides, strides,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing MaxPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}" f"Testing MaxPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
...@@ -104,8 +105,11 @@ def test( ...@@ -104,8 +105,11 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
descriptor = infiniopMaxPoolDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopMaxPoolDescriptor_t()
check_error( check_error(
lib.infiniopCreateMaxPoolDescriptor( lib.infiniopCreateMaxPoolDescriptor(
handle, handle,
......
...@@ -65,6 +65,7 @@ def test( ...@@ -65,6 +65,7 @@ def test(
y_stride=None, y_stride=None,
w12_stride=None, w12_stride=None,
w3_stride=None, w3_stride=None,
sync=None
): ):
print( print(
f"Testing MLP on {torch_device} with num_tokens:{num_tokens} hidden_size:{hidden_size} intermediate_size:{intermediate_size}" f"Testing MLP on {torch_device} with num_tokens:{num_tokens} hidden_size:{hidden_size} intermediate_size:{intermediate_size}"
...@@ -97,6 +98,10 @@ def test( ...@@ -97,6 +98,10 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
w12_tensor = to_tensor(w12, lib) w12_tensor = to_tensor(w12, lib)
w3_tensor = to_tensor(w3, lib) w3_tensor = to_tensor(w3, lib)
if sync is not None:
sync()
descriptor = infiniopMLPDescriptor_t() descriptor = infiniopMLPDescriptor_t()
check_error( check_error(
lib.infiniopCreateMLPDescriptor( lib.infiniopCreateMLPDescriptor(
......
...@@ -103,6 +103,7 @@ def test( ...@@ -103,6 +103,7 @@ def test(
topk, topk,
temperature, temperature,
dtype=torch.float16, dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}" f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}"
...@@ -122,6 +123,9 @@ def test( ...@@ -122,6 +123,9 @@ def test(
indices_tensor.descriptor.contents.dt = InfiniDtype.U64 # treat int64 as uint64 indices_tensor.descriptor.contents.dt = InfiniDtype.U64 # treat int64 as uint64
if sync is not None:
sync()
descriptor = infiniopRandomSampleDescriptor_t() descriptor = infiniopRandomSampleDescriptor_t()
check_error( check_error(
lib.infiniopCreateRandomSampleDescriptor( lib.infiniopCreateRandomSampleDescriptor(
......
...@@ -131,6 +131,7 @@ def test( ...@@ -131,6 +131,7 @@ def test(
x_stride, x_stride,
y_stride, y_stride,
dtype=torch.float16, dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing Rerrange on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype}" f"Testing Rerrange on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype}"
...@@ -145,6 +146,9 @@ def test( ...@@ -145,6 +146,9 @@ def test(
] ]
x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]] x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
if sync is not None:
sync()
descriptor = infiniopRearrangeDescriptor_t() descriptor = infiniopRearrangeDescriptor_t()
check_error( check_error(
......
...@@ -55,6 +55,7 @@ def test( ...@@ -55,6 +55,7 @@ def test(
tensor_shape, tensor_shape,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
inplace=Inplace.OUT_OF_PLACE, inplace=Inplace.OUT_OF_PLACE,
sync=None
): ):
print( print(
f"Testing Relu on {torch_device} with tensor_shape:{tensor_shape} dtype:{tensor_dtype} inplace: {inplace.name}" f"Testing Relu on {torch_device} with tensor_shape:{tensor_shape} dtype:{tensor_dtype} inplace: {inplace.name}"
...@@ -78,8 +79,11 @@ def test( ...@@ -78,8 +79,11 @@ def test(
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor
descriptor = infiniopReluDescriptor_t()
if sync is not None:
sync()
descriptor = infiniopReluDescriptor_t()
check_error( check_error(
lib.infiniopCreateReluDescriptor( lib.infiniopCreateReluDescriptor(
handle, handle,
......
...@@ -72,6 +72,7 @@ def test( ...@@ -72,6 +72,7 @@ def test(
x_stride, x_stride,
w_dtype=torch.float16, w_dtype=torch.float16,
dtype=torch.float16, dtype=torch.float16,
sync=None
): ):
print( print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}" f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
...@@ -89,9 +90,11 @@ def test( ...@@ -89,9 +90,11 @@ def test(
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride]) for tensor, stride in zip([x, y], [x_stride, y_stride])
] ]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]] x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
if sync is not None:
sync()
descriptor = infiniopRMSNormDescriptor_t() descriptor = infiniopRMSNormDescriptor_t()
check_error( check_error(
......
...@@ -117,6 +117,7 @@ def test( ...@@ -117,6 +117,7 @@ def test(
y_strides=None, y_strides=None,
inplace=Inplace.OUT_OF_PLACE, inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32, dtype=torch.float32,
sync=None
): ):
if inplace == Inplace.INPLACE_X: if inplace == Inplace.INPLACE_X:
y_strides = x_strides y_strides = x_strides
...@@ -147,8 +148,8 @@ def test( ...@@ -147,8 +148,8 @@ def test(
else: else:
y_tensor = to_tensor(y, lib) y_tensor = to_tensor(y, lib)
if torch_device == "npu": if sync is not None:
synchronize_device(torch_device) sync()
check_error( check_error(
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment