Unverified Commit 309878f0 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #91 from YdrMaster/main

issue/87/refactor: 重构 cuda handle,并修改 cpu handle 和 matmul 命名空间
parents 2538df1b 01b2f8ab
......@@ -35,4 +35,14 @@ jobs:
run: xmake f -cv
- name: build with xmake
run: xmake build && xmake install
run: xmake build
- name: install to INFINI_ROOT
if: matrix.os != 'windows-latest'
run: xmake install
- name: python test
if: matrix.os != 'windows-latest'
run: |
pip install torch
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/matmul.py --cpu
#ifndef __INFINIOP_API_H__
#define __INFINIOP_API_H__
#include "infiniop/tensor_descriptor.h"
#include "infiniop/handle.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/attention.h"
......@@ -19,5 +19,6 @@
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rotary_embedding.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"
#endif // __INFINIOP_API_H__
......@@ -7,7 +7,7 @@ struct InfiniopHandle;
typedef InfiniopHandle *infiniopHandle_t;
__C __export infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, infiniDevice_t device);
__C __export infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr);
__C __export infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle);
......
#include "cpu_handle.h"
namespace infiniop::cpu {
namespace device::cpu {
Handle::Handle() : InfiniopHandle{INFINI_DEVICE_CPU, 0} {}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr) {
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int) {
*handle_ptr = new Handle{};
return INFINI_STATUS_SUCCESS;
}
} // namespace infiniop::cpu
} // namespace device::cpu
......@@ -3,13 +3,15 @@
#include "../../handle.h"
namespace infiniop::cpu {
namespace device::cpu {
class Handle : public InfiniopHandle {
Handle();
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr);
static infiniStatus_t create(InfiniopHandle **handle_ptr, int);
};
} // namespace infiniop::cpu
} // namespace device::cpu
#endif
#ifndef __INFINIOP_COMMON_CUDA_H__
#define __INFINIOP_COMMON_CUDA_H__
#define MAX_THREADS_PER_BLOCK 1024
#define MAX_WARP_PER_BLOCK 32
#define WARP_SIZE 32
#include "../../../utils.h"
#include <iostream>
#define CHECK_CUDA_OR_RETURN(API, ERROR) CHECK_API_OR(API, cudaSuccess, return ERROR)
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
#include "../pool.h"
#include "cuda_handle.h"
#include "infinicore.h"
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cudnn.h>
#include <memory>
// Legacy (pre-refactor) CUDA handle: bundles the device identity with pooled
// cuBLAS/cuDNN handles and cached device properties (see createCudaHandle).
struct InfiniopCudaHandle {
    infiniDevice_t device;        // which CUDA-like backend this handle targets
    int device_id;                // CUDA device ordinal the handle was created on
    std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool; // reusable cuBLAS handles
    std::shared_ptr<Pool<cudnnHandle_t>> cudnn_handle_pool;   // reusable cuDNN handles
    cudaDeviceProp prop;          // cached result of cudaGetDeviceProperties
    int compute_capability_major; // from cudaDevAttrComputeCapabilityMajor
    int compute_capability_minor; // from cudaDevAttrComputeCapabilityMinor
};
// Run `f` with a cublasHandle_t drawn from `pool`, bound to `stream`; the
// handle is pushed back into the pool afterwards so it can be reused.
template <typename T>
void use_cublas(std::shared_ptr<Pool<cublasHandle_t>> &pool, cudaStream_t stream, const T &f) {
    auto handle = pool->pop();
    // NOTE(review): when the pool is empty, a handle is created in place via
    // &(*handle). This assumes Pool::pop() returns an optional-like object
    // whose storage is valid to write through while disengaged — TODO confirm
    // Pool's contract. cublasCreate's status is not checked here.
    if (!handle) {
        cublasCreate(&(*handle));
    }
    cublasSetStream(*handle, stream);
    f(*handle);
    pool->push(std::move(*handle));
}
// Run `f` with a cudnnHandle_t drawn from `pool`, bound to `stream`; the
// handle is pushed back into the pool afterwards so it can be reused.
template <typename T>
void use_cudnn(std::shared_ptr<Pool<cudnnHandle_t>> &pool, cudaStream_t stream, const T &f) {
    auto handle = pool->pop();
    // NOTE(review): same pattern as use_cublas — creates a handle in place
    // when the pool is empty; assumes the popped object's storage may be
    // written through while disengaged (TODO confirm Pool's contract), and
    // cudnnCreate's status is not checked.
    if (!handle) {
        cudnnCreate(&(*handle));
    }
    cudnnSetStream(*handle, stream);
    f(*handle);
    pool->push(std::move(*handle));
}
// Map an infiniDtype_t onto the corresponding cuDNN data type.
// Any dtype without a cuDNN counterpart falls back to CUDNN_DATA_FLOAT.
inline cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
    if (dt == INFINI_DTYPE_F16) {
        return CUDNN_DATA_HALF;
    } else if (dt == INFINI_DTYPE_F32) {
        return CUDNN_DATA_FLOAT;
    } else if (dt == INFINI_DTYPE_F64) {
        return CUDNN_DATA_DOUBLE;
    } else if (dt == INFINI_DTYPE_BF16) {
        return CUDNN_DATA_BFLOAT16;
    } else if (dt == INFINI_DTYPE_I8) {
        return CUDNN_DATA_INT8;
    } else if (dt == INFINI_DTYPE_I32) {
        return CUDNN_DATA_INT32;
    } else if (dt == INFINI_DTYPE_I64) {
        return CUDNN_DATA_INT64;
    } else if (dt == INFINI_DTYPE_U8) {
        return CUDNN_DATA_UINT8;
    }
    // same fallback the original switch's default case used
    return CUDNN_DATA_FLOAT;
}
// return the memory offset of original tensor, given the flattened index of
// broadcasted tensor
// Return the memory offset into the original tensor, given the flattened
// index of an element in the broadcasted tensor. `broadcasted_strides` are
// the contiguous strides of the broadcast shape; `target_strides` carry 0
// in broadcast dimensions so those coordinates contribute nothing.
inline __device__ __host__ size_t indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    size_t remaining = flat_index;
    for (size_t dim = 0; dim < ndim; ++dim) {
        // coordinate along `dim` in the broadcast shape
        const size_t coord = remaining / broadcasted_strides[dim];
        remaining -= coord * broadcasted_strides[dim];
        offset += coord * target_strides[dim];
    }
    return offset;
}
// get the memory offset of the given element in a tensor given its flat index
// Get the memory offset of an element in a (possibly strided) tensor
// given its flat index, by peeling coordinates from the innermost
// dimension outwards.
inline __device__ __host__ size_t indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    size_t dim = ndim;
    while (dim > 0) {
        --dim;
        offset += (flat_index % shape[dim]) * strides[dim];
        flat_index /= shape[dim];
    }
    return offset;
}
#endif // __INFINIOP_COMMON_CUDA_H__
#include "common_cuda.cuh"
infiniStatus_t createCudaHandle(infiniopCudaHandle_t *handle_ptr, infiniDevice_t cuda_device_type) {
// Create a new cublas handle pool
int device_id = 0;
CHECK_CUDA_OR_RETURN(cudaGetDevice(&device_id), INFINI_STATUS_DEVICE_NOT_INITIALIZED);
auto pool = std::make_shared<Pool<cublasHandle_t>>();
cublasHandle_t handle;
cublasCreate(&handle);
pool->push(std::move(handle));
// create a cudnn handle pool
auto cudnn_pool = std::make_shared<Pool<cudnnHandle_t>>();
cudnnHandle_t cudnn_handle;
CHECK_CUDNN(cudnnCreate(&cudnn_handle));
cudnn_pool->push(std::move(cudnn_handle));
// set CUDA device property
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device_id);
// set device compute capability numbers
int capability_major;
int capability_minor;
cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device_id);
cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device_id);
*handle_ptr = new InfiniopCudaHandle{
cuda_device_type,
device_id,
std::move(pool),
std::move(cudnn_pool),
std::move(prop),
capability_major,
capability_minor,
};
#include "cuda_handle.cuh"
return INFINI_STATUS_SUCCESS;
namespace device::cuda {
// Base CUDA handle constructor: records the backend/device identity in the
// common InfiniopHandle and allocates the shared Internal state (the
// cuBLAS/cuDNN handle pools).
Handle::Handle(infiniDevice_t device, int device_id)
    : InfiniopHandle{device, device_id},
      _internal(std::make_shared<Handle::Internal>()) {}

// Accessor for the shared Internal state; ops copy this shared_ptr so the
// pools outlive the handle if needed (see matmul's Opaque).
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
    return _internal;
}
infiniStatus_t destroyCudaHandle(infiniopCudaHandle_t handle_ptr) {
handle_ptr->cublas_handle_pool = nullptr;
handle_ptr->cudnn_handle_pool = nullptr;
delete handle_ptr;
// Convenience alias for callbacks that receive a library handle.
template <typename T>
using Fn = std::function<void(T)>;

// Borrow a cuBLAS handle from the pool (creating one if the pool is empty),
// bind it to `stream`, invoke `f`, and return the handle to the pool.
void Handle::Internal::use_cublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    auto handle = blas_handles.pop();
    // NOTE(review): creates a handle in place via &(*handle) when the pool is
    // empty; assumes pop() yields an optional-like object whose storage may
    // be written through while disengaged — TODO confirm Pool's contract.
    // cublasCreate's status is not checked.
    if (!handle) {
        cublasCreate(&(*handle));
    }
    cublasSetStream(*handle, stream);
    f(*handle);
    blas_handles.push(std::move(*handle));
}
// Borrow a cuDNN handle from the pool (creating one if the pool is empty),
// bind it to `stream`, invoke `f`, and return the handle to the pool.
void Handle::Internal::use_cudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const {
    auto handle = dnn_handles.pop();
    // NOTE(review): same in-place create pattern as use_cublas — TODO confirm
    // Pool's contract; cudnnCreate's status is not checked.
    if (!handle) {
        cudnnCreate(&(*handle));
    }
    cudnnSetStream(*handle, stream);
    f(*handle);
    dnn_handles.push(std::move(*handle));
}
// Translate an infiniDtype_t into the matching cuDNN data type.
// Dtypes with no cuDNN counterpart map to CUDNN_DATA_FLOAT.
cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
    struct Entry {
        infiniDtype_t from;
        cudnnDataType_t to;
    };
    static const Entry table[] = {
        {INFINI_DTYPE_F16, CUDNN_DATA_HALF},
        {INFINI_DTYPE_F32, CUDNN_DATA_FLOAT},
        {INFINI_DTYPE_F64, CUDNN_DATA_DOUBLE},
        {INFINI_DTYPE_BF16, CUDNN_DATA_BFLOAT16},
        {INFINI_DTYPE_I8, CUDNN_DATA_INT8},
        {INFINI_DTYPE_I32, CUDNN_DATA_INT32},
        {INFINI_DTYPE_I64, CUDNN_DATA_INT64},
        {INFINI_DTYPE_U8, CUDNN_DATA_UINT8},
    };
    for (const auto &entry : table) {
        if (entry.from == dt) {
            return entry.to;
        }
    }
    // mirrors the original switch's default case
    return CUDNN_DATA_FLOAT;
}
// NVIDIA specialization of the generic CUDA handle.
namespace nvidia {

// Tag the base cuda::Handle with the NVIDIA device type.
Handle::Handle(int device_id)
    : cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {}

// Factory used by infiniopCreateHandle's CREATE macro: allocates a new
// NVIDIA handle for `device_id` and hands ownership to the caller
// (released later by infiniopDestroyHandle's DELETE macro).
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(device_id);
    return INFINI_STATUS_SUCCESS;
}

} // namespace nvidia
} // namespace device::cuda
#ifndef __INFINIOP_CUDA_INTERNAL_H__
#define __INFINIOP_CUDA_INTERNAL_H__
#include "../pool.h"
#include "cuda_handle.h"
#include <cublas_v2.h>
#include <cudnn.h>
#include <functional>
namespace device::cuda {
// Private shared state of a CUDA handle: pools of reusable cuBLAS and cuDNN
// handles. Shared between the Handle and operator descriptors via
// shared_ptr (see Handle::internal()).
// NOTE(review): the member functions are const yet pop/push the pools —
// relies on Pool's own mutability/thread-safety guarantees; TODO confirm.
class Handle::Internal {
    Pool<cublasHandle_t> blas_handles; // recycled cuBLAS handles
    Pool<cudnnHandle_t> dnn_handles;   // recycled cuDNN handles

public:
    // Borrow a cuBLAS handle bound to `stream`, invoke `f`, return it to the pool.
    void use_cublas(cudaStream_t stream, const std::function<void(cublasHandle_t)> &f) const;
    // Borrow a cuDNN handle bound to `stream`, invoke `f`, return it to the pool.
    void use_cudnn(cudaStream_t stream, const std::function<void(cudnnHandle_t)> &f) const;
};
cudnnDataType_t getCudnnDtype(infiniDtype_t dt);
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
// Return the memory offset into the original tensor, given the flattened
// index of an element in the broadcasted tensor. `broadcasted_strides` are
// the contiguous strides of the broadcast shape; `target_strides` hold 0 in
// broadcast dimensions so those coordinates drop out of the offset.
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    for (size_t dim = 0; dim < ndim; ++dim) {
        // coordinate along `dim` in the broadcast shape
        const size_t coord = flat_index / broadcasted_strides[dim];
        flat_index -= coord * broadcasted_strides[dim];
        offset += coord * target_strides[dim];
    }
    return offset;
}
// get the memory offset of the given element in a tensor given its flat index
// Get the memory offset of an element in a (possibly strided) tensor given
// its flat index, peeling coordinates from the innermost dimension outwards.
__forceinline__ __device__ __host__ size_t
indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    for (size_t k = 0; k < ndim; ++k) {
        const size_t dim = ndim - 1 - k; // iterate innermost-first
        offset += (flat_index % shape[dim]) * strides[dim];
        flat_index /= shape[dim];
    }
    return offset;
}
} // namespace device::cuda
#endif // __INFINIOP_CUDA_INTERNAL_H__
......@@ -2,12 +2,32 @@
#define __INFINIOP_CUDA_HANDLE_H__
#include "../../handle.h"
#include <memory>
struct InfiniopCudaHandle;
typedef struct InfiniopCudaHandle *infiniopCudaHandle_t;
namespace device::cuda {
infiniStatus_t createCudaHandle(infiniopCudaHandle_t *handle_ptr, infiniDevice_t cuda_device_type);
struct Handle : public InfiniopHandle {
class Internal;
auto internal() const -> const std::shared_ptr<Internal> &;
infiniStatus_t destroyCudaHandle(infiniopCudaHandle_t handle_ptr);
protected:
Handle(infiniDevice_t device, int device_id);
#endif
private:
std::shared_ptr<Internal> _internal;
};
namespace nvidia {
class Handle : public cuda::Handle {
Handle(int device_id);
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
} // namespace nvidia
} // namespace device::cuda
#endif // __INFINIOP_CUDA_HANDLE_H__
#include "infiniop/handle.h"
#include "../../utils.h"
#include "infinirt.h"
#ifdef ENABLE_CPU_API
#include "cpu/cpu_handle.h"
#endif
......@@ -15,21 +18,25 @@
#include "kunlun/kunlun_handle.h"
#endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
infiniDevice_t device) {
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
if (handle_ptr == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
infiniDevice_t device;
int device_id;
CHECK_STATUS(infinirtGetDevice(&device, &device_id));
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return device::NAMESPACE::Handle::create(handle_ptr, device_id)
switch (device) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return infiniop::cpu::Handle::create(handle_ptr);
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return createCudaHandle((infiniopCudaHandle_t *)handle_ptr, device);
}
CREATE(INFINI_DEVICE_NVIDIA, cuda::nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
......@@ -46,25 +53,27 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
return createKunlunHandle((infiniopKunlunHandle_t *)handle_ptr);
}
#endif
}
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<infiniop::NAMESPACE::Handle *>(handle); \
return INFINI_STATUS_SUCCESS;
delete reinterpret_cast<device::NAMESPACE::Handle *>(handle); \
return INFINI_STATUS_SUCCESS
switch (handle->device) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu)
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return destroyCudaHandle((infiniopCudaHandle_t)handle);
}
DELETE(INFINI_DEVICE_NVIDIA, cuda::nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
......@@ -84,5 +93,6 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
......@@ -6,7 +6,7 @@
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/level2/aclnn_gemm.h>
namespace matmul::ascend {
namespace op::matmul::ascend {
struct Descriptor::Opaque {
mutable aclOpExecutor *executor;
......@@ -135,4 +135,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace matmul::ascend
} // namespace op::matmul::ascend
......@@ -3,7 +3,7 @@
#include "../../../devices/bang/common_bang.h"
#include <cnnl_extra.h>
namespace matmul::bang {
namespace op::matmul::bang {
struct Descriptor::Opaque {
cnnlMatMulDescriptor_t op;
......@@ -157,4 +157,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace matmul::bang
} // namespace op::matmul::bang
......@@ -5,7 +5,8 @@
#include "../../tensor.h"
#include <algorithm>
namespace matmul {
namespace op::matmul {
struct BlasMatrix {
size_t ndim;
size_t batch;
......@@ -118,6 +119,7 @@ struct MatmulInfo {
k = a_matrix.cols;
}
};
} // namespace matmul
} // namespace op::matmul
#endif // __BLAS_H__
......@@ -2,7 +2,7 @@
#include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/cpu_handle.h"
namespace matmul::cpu {
namespace op::matmul::cpu {
Descriptor::~Descriptor() = default;
......@@ -12,7 +12,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<infiniop::cpu::Handle *>(handle_);
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
......@@ -96,4 +96,4 @@ infiniStatus_t Descriptor::calculate(
}
}
} // namespace matmul::cpu
} // namespace op::matmul::cpu
#include "../../../devices/cuda/common_cuda.cuh"
#include "../../../devices/cuda/cuda_handle.cuh"
#include "matmul_cuda.cuh"
namespace matmul::cuda {
namespace op::matmul::cuda {
struct Descriptor::Opaque {
std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool;
std::shared_ptr<device::cuda::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -17,7 +17,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<infiniopCudaHandle_t>(handle_);
auto handle = reinterpret_cast<device::cuda::nvidia::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
......@@ -32,7 +32,7 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor(
dtype, info, 0,
new Opaque{handle->cublas_handle_pool},
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -76,7 +76,7 @@ infiniStatus_t Descriptor::calculate(
auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
use_cublas(_opaque->cublas_handle_pool,
_opaque->internal->use_cublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
cublasGemmStridedBatchedEx(
......@@ -107,4 +107,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace matmul::cuda
} // namespace op::matmul::cuda
......@@ -2,7 +2,7 @@
#include "../../../devices/kunlun/common_kunlun.h"
#include "../../utils.h"
namespace matmul::kunlun {
namespace op::matmul::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
......@@ -110,4 +110,4 @@ infiniStatus_t Descriptor::calculate(
}
}
} // namespace matmul::kunlun
} // namespace op::matmul::kunlun
......@@ -46,7 +46,7 @@
#define DESCRIPTOR(NAMESPACE) \
\
namespace matmul::NAMESPACE { \
namespace op::matmul::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
......
......@@ -27,9 +27,9 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor(
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return matmul::NAMESPACE::Descriptor::create( \
return op::matmul::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<matmul::NAMESPACE::Descriptor **>(desc_ptr), \
reinterpret_cast<op::matmul::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
......@@ -66,7 +66,7 @@ infiniopGetMatmulWorkspaceSize(
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc)->workspace_size; \
*size = reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc)->workspace_size; \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
......@@ -106,7 +106,7 @@ __C infiniStatus_t infiniopMatmul(
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc) \
return reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \
c, beta, \
a, b, alpha, \
......@@ -142,7 +142,7 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc); \
delete reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
......
......@@ -37,6 +37,20 @@ class Handle(Structure):
infiniopHandle_t = POINTER(Handle)
class InfiniLib:
    """Facade over the runtime and operator libraries.

    Attribute lookups that miss on this object are forwarded to the
    operator library first, then the runtime library, so callers can use
    a single handle for symbols from either shared library.
    """

    def __init__(self, librt, libop):
        self.librt = librt
        self.libop = libop

    def __getattr__(self, name):
        # __getattr__ only fires when normal lookup fails, so the
        # librt/libop attributes themselves never recurse through here.
        for lib in (self.libop, self.librt):
            if hasattr(lib, name):
                return getattr(lib, name)
        raise AttributeError(f"Attribute {name} not found in library")
# Open operators library
def open_lib():
def find_library_in_ld_path(subdir, library_name):
......@@ -51,14 +65,22 @@ def open_lib():
system_name = platform.system()
# Load the library
if system_name == "Windows":
library_path = find_library_in_ld_path("bin", "infiniop.dll")
libop_path = find_library_in_ld_path("bin", "infiniop.dll")
librt_path = find_library_in_ld_path("bin", "infinirt.dll")
elif system_name == "Linux":
library_path = find_library_in_ld_path("lib", "libinfiniop.so")
libop_path = find_library_in_ld_path("lib", "libinfiniop.so")
librt_path = find_library_in_ld_path("lib", "libinfinirt.so")
assert (
library_path is not None
libop_path is not None
), f"Cannot find infiniop.dll or libinfiniop.so. Check if INFINI_ROOT is set correctly."
lib = ctypes.CDLL(library_path)
assert (
librt_path is not None
), f"Cannot find infinirt.dll or libinfinirt.so. Check if INFINI_ROOT is set correctly."
librt = ctypes.CDLL(librt_path)
libop = ctypes.CDLL(libop_path)
lib = InfiniLib(librt, libop)
lib.infiniopCreateTensorDescriptor.argtypes = [
POINTER(infiniopTensorDescriptor_t),
c_uint64,
......@@ -69,9 +91,11 @@ def open_lib():
lib.infiniopCreateTensorDescriptor.restype = c_int
lib.infiniopDestroyTensorDescriptor.argtypes = [infiniopTensorDescriptor_t]
lib.infiniopDestroyTensorDescriptor.restype = c_int
lib.infiniopCreateHandle.argtypes = [POINTER(infiniopHandle_t), c_int, c_int]
lib.infiniopCreateHandle.argtypes = [POINTER(infiniopHandle_t)]
lib.infiniopCreateHandle.restype = c_int
lib.infiniopDestroyHandle.argtypes = [infiniopHandle_t]
lib.infiniopDestroyHandle.restype = c_int
lib.infinirtSetDevice.argtypes = [c_int, c_int]
lib.infinirtSetDevice.restype = c_int
return lib
......@@ -57,9 +57,9 @@ def create_workspace(size, torch_device):
return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
def create_handle(lib, device, id=0):
def create_handle(lib):
handle = infiniopHandle_t()
check_error(lib.infiniopCreateHandle(ctypes.byref(handle), device, id))
check_error(lib.infiniopCreateHandle(ctypes.byref(handle)))
return handle
......@@ -392,7 +392,8 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
to be passed to `test_func`.
- tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test.
"""
handle = create_handle(lib, device)
lib.infinirtSetDevice(device, ctypes.c_int(0))
handle = create_handle(lib)
try:
for test_case in test_cases:
for tensor_dtype in tensor_dtypes:
......@@ -435,6 +436,7 @@ def get_test_devices(args):
devices_to_test.append(InfiniDeviceEnum.ASCEND)
if args.kunlun:
import torch_xmlir
devices_to_test.append(InfiniDeviceEnum.KUNLUN)
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment