Commit 9faf1ffc authored by zhangyue's avatar zhangyue
Browse files

issue/209: kunlun elementwise first commit

parent 1a4cfb99
......@@ -16,7 +16,7 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
......
#ifndef __INFINIOP_KUNLUN_COMMON_H__
#define __INFINIOP_KUNLUN_COMMON_H__
#ifndef __INFINIOP_KUNLUN_KERNEL_COMMON_H__
#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__
// This header file will only be include by .xpu file
#include "kunlun_kernel_dtype.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
namespace device::kunlun::kernel {
// Get mask for kunlun xpu 512bit register calculation
// if data is not enough to 512bit, padding zero and use
// mask to identify real data
......@@ -26,6 +29,37 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
}
}
inline __device__ size_t indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
flat_index %= broadcasted_strides[i].value;
mfence();
}
return res;
}
inline __device__ size_t indexToOffset(
size_t flat_index,
size_t ndim,
const _size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i].value) * strides[i].value;
flat_index /= shape[i].value;
mfence();
}
return res;
}
} // namespace device::kunlun::kernel
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// kunlun ptrdiff_t* is used to save ptrdiff_t array
// copied from host
typedef struct ptrdiff_t {
long value; // 32 bit
long padding; // 32 bit
} ptrdiff_t;
// same as ptrdiff
typedef struct _size_t {
size_t value;
size_t padding;
} _size_t;
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
// std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
#include "xpu/kernel/xtdk_io.h"
// #include <cstdio>
using namespace device::kunlun::kernel;
/**
* @brief Computes input tile offset
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const ptrdiff_t *input_strides;
const ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
#pragma unroll
// Copy inputs to buf
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Calculate elementwise
// Inputs save all operands
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const ptrdiff_t *output_strides_gm,
const ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
// Only support 3 mode elementwise
static_assert(N < 4, "elementwise Kernel support mode < 4 calculate");
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Launch operater
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const ptrdiff_t *>(output_strides), \
reinterpret_cast<const ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
......@@ -62,7 +62,7 @@ infiniStatus_t calculate(
(kunlunStream_t)stream,
[&](xdnnHandle_t handle) {
for (size_t i = 0; i < info.batch; i++) {
CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
handle,
(Tdata *)((char *)a + i * info.a_matrix.stride * unit),
(Tdata *)((char *)b + i * info.b_matrix.stride * unit),
......
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
// Element wise mul used in x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
......
#include "swiglu_kunlun.h"
// Op interface declare
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::kunlun {
typedef struct SwiGLUOp {
static constexpr size_t num_inputs = 2;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchSwiGLUKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} SwiGLUOp;
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F32);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F32:
return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::kunlun
\ No newline at end of file
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
ELEMENTWISE_DESCRIPTOR(swiglu, kunlun)
#endif // __SWIGLU_KUNLUN_H__
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../elementwise/kunlun/elementwise_kunlun_kernel.h"
/// @brief Define swiglu op for local mem
typedef struct SwiGLUOp {
private:
template <typename T>
inline __device__ T sigmoid(T x) const {
return 1.0f / (1.0f + exp(-x));
}
public:
// This static number must be set in other Ops
static constexpr size_t num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T up = inputs[0];
T gate = inputs[1];
T out = gate * sigmoid(gate) * up;
return out;
}
} SwiGLUOp;
// Defination for swiglu kernel interface
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)
// Template instantiate
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)
#endif // __SWIGLU_KUNLUN_H__
......@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle,
......@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangCreateSwiGLUDescriptor((BangHandle_t)handle,
......@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
......@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopSwiGLU(
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangSwiGLU((SwiGLUBangDescriptor_t)desc, c, a, b, stream);
......@@ -168,6 +180,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroySwiGLUDescriptor((SwiGLUBangDescriptor_t)desc);
......
#ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__
#include "../../devices/kunlun/kunlun_common.h"
#include "../../devices/kunlun/kunlun_kernel_common.h"
namespace op::common_kunlun::reduce_op {
using namespace device::kunlun::kernel;
// Use 16 floats instruction to calculate reduce
// data_ptr is the pointer of LM
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
......
......@@ -83,6 +83,7 @@ def test(
b_stride=None,
c_stride=None,
dtype=torch.float16,
sync=None
):
print(
f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
......@@ -104,6 +105,9 @@ def test(
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
if sync is not None:
sync()
descriptor = infiniopGemmDescriptor_t()
check_error(
lib.infiniopCreateGemmDescriptor(
......
......@@ -423,6 +423,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
infiniDeviceEnum_str_map[device],
*test_case,
tensor_dtype,
get_sync_func(device)
)
finally:
destroy_handle(lib, handle)
......@@ -471,3 +472,14 @@ def get_test_devices(args):
devices_to_test = [InfiniDeviceEnum.CPU]
return devices_to_test
def get_sync_func(device):
import torch
if device == "cpu":
sync = None
else:
sync = getattr(torch, infiniDeviceEnum_str_map[device]).synchronize
return sync
......@@ -27,17 +27,17 @@ _TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((1, 4), (1, 4), (4,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
# ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
# ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# x types used for testing
_TENSOR_DTYPES = [torch.float16]
_TENSOR_DTYPES = [torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
torch.float32: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......@@ -72,6 +72,7 @@ def test(
x_stride,
w_dtype=torch.float16,
dtype=torch.float16,
sync=None
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
......@@ -89,9 +90,11 @@ def test(
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
if sync is not None:
sync()
descriptor = infiniopRMSNormDescriptor_t()
check_error(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment