Unverified Commit cbd573f5 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #169 from InfiniTensor/issue/48

Issue/48 Rope CPU & CUDA
parents 95fd5c1b 39c133c4
...@@ -12,11 +12,12 @@ os.chdir(PROJECT_DIR) ...@@ -12,11 +12,12 @@ os.chdir(PROJECT_DIR)
def run_tests(args): def run_tests(args):
failed = [] failed = []
for test in [ for test in [
"causal_softmax.py",
"gemm.py", "gemm.py",
"random_sample.py",
"rms_norm.py", "rms_norm.py",
"causal_softmax.py", "rope.py",
"swiglu.py", "swiglu.py",
"random_sample.py",
]: ]:
result = subprocess.run( result = subprocess.run(
f"python {test} {args}", text=True, encoding="utf-8", shell=True f"python {test} {args}", text=True, encoding="utf-8", shell=True
......
#include "rope_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope::cpu {
// The CPU descriptor is created with a null opaque pointer (see Descriptor::create below),
// so the compiler-generated destructor has nothing extra to release.
Descriptor::~Descriptor() = default;
// Builds a CPU RoPE descriptor.
//
// The tensor descriptors are validated by RoPEInfo::createRoPEInfo; when that
// validation fails, CHECK_RESULT returns from this function before anything is
// allocated. On success *desc_ptr receives a heap-allocated Descriptor with a
// workspace size of 0 and no opaque state, and INFINI_STATUS_SUCCESS is returned.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {

    auto rope_info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(rope_info);

    // Device type and id are copied from the handle into the new descriptor.
    auto cpu_handle = reinterpret_cast<device::cpu::Handle *>(handle_);

    constexpr size_t workspace_bytes = 0; // the CPU implementation needs no scratch memory
    *desc_ptr = new Descriptor(
        rope_info.take(),
        workspace_bytes,
        nullptr,
        cpu_handle->device,
        cpu_handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
// Applies the rotary embedding on the CPU.
//
// For every (head, token) pair the function walks the `info.table_dim` adjacent
// element pairs (2i, 2i + 1) of the last dimension and rotates each pair by the
// sin/cos entry found at row `pos_ids[token]`, column i of the tables:
//   y[2i]     = x[2i] * cos - x[2i + 1] * sin
//   y[2i + 1] = x[2i] * sin + x[2i + 1] * cos
// fp16 data is widened to float for the arithmetic and narrowed back on store.
// The outer loop over heads is parallelised with OpenMP.
// NOTE(review): pos_ids values are used as table row indices without a bounds
// check against info.table_len - callers presumably guarantee the range.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
                             Tdata *y,
                             const Tdata *x,
                             const Tindex *pos_ids,
                             const Tdata *sin_table,
                             const Tdata *cos_table) {
#pragma omp parallel for
    for (ptrdiff_t head = 0; head < ptrdiff_t(info.nhead); head++) {
        for (size_t token = 0; token < info.seqlen; token++) {
            // Element offsets of this (token, head) row in the input and output tensors.
            const size_t src_base = token * info.x_stride_seqlen + head * info.x_stride_nhead;
            const size_t dst_base = token * info.y_stride_seqlen + head * info.y_stride_nhead;
            // Row of the sin/cos tables selected by this token's position id.
            const size_t row = size_t(pos_ids[token]);
            const size_t angle_base = row * info.table_dim;

            for (size_t pair = 0; pair < info.table_dim; pair++) {
                const size_t even = 2 * pair;
                const size_t odd = even + 1;
                const size_t angle = angle_base + pair;

                if constexpr (std::is_same<Tdata, fp16_t>::value) {
                    // Half precision: compute in float, convert back when storing.
                    const float a = utils::cast<float>(x[src_base + even]);
                    const float b = utils::cast<float>(x[src_base + odd]);
                    const float s = utils::cast<float>(sin_table[angle]);
                    const float c = utils::cast<float>(cos_table[angle]);

                    y[dst_base + even] = utils::cast<fp16_t>(a * c - b * s);
                    y[dst_base + odd] = utils::cast<fp16_t>(a * s + b * c);
                } else {
                    // Both inputs are read before either output is written, so y may alias x.
                    const Tdata a = x[src_base + even];
                    const Tdata b = x[src_base + odd];
                    const Tdata s = sin_table[angle];
                    const Tdata c = cos_table[angle];

                    y[dst_base + even] = a * c - b * s;
                    y[dst_base + odd] = a * s + b * c;
                }
            }
        }
    }

    return INFINI_STATUS_SUCCESS;
}
// Casts the untyped buffers of Descriptor::calculate (y, x, pos_ids, sin_table,
// cos_table) to the requested data/index types and forwards them, together with
// the descriptor's _info, to the calculateRoPE template above.
// Only usable where those names are in scope.
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
// Dispatches on the position-id dtype (_info.pos_type) for a fixed data type TDATA.
// Every branch returns from the enclosing function; a dtype that is not one of the
// eight listed integer types yields INFINI_STATUS_BAD_TENSOR_DTYPE.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Runs RoPE on the CPU for the tensors described at creation time.
//
// The data dtype stored in _info selects the element type; ROPE_TYPE then selects
// the position-id type and returns the result of the typed calculateRoPE call.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE when either dtype is not handled.
// `workspace`, `workspace_size` and `stream` are accepted for interface
// uniformity and are not read by this implementation.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {

    (void)workspace;
    (void)workspace_size;
    (void)stream;

    // Each ROPE_TYPE expansion returns on every path, so no case falls through.
    switch (_info.data_type) {
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F16:
        ROPE_TYPE(fp16_t);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cpu
#ifndef __INFINIOP_ROPE_CPU_H__
#define __INFINIOP_ROPE_CPU_H__
#include "../rope.h"
DESCRIPTOR(cpu)
#endif // __INFINIOP_ROPE_CPU_H__
#include "../../../devices/cuda/cuda_common.cuh"
#include "rope_cuda.cuh"
#include "rope_cuda_kernel.cuh"
namespace op::rope::cuda {
// Backend-private state of the CUDA descriptor.
// Holds a shared reference to the CUDA handle's internal object; calculate() reads
// maxThreadsPerBlock() from it when choosing the launch configuration.
struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
};
// Releases the Opaque object allocated with `new` in Descriptor::create.
Descriptor::~Descriptor() {
delete _opaque;
}
// Builds a CUDA RoPE descriptor.
//
// The tensor descriptors are validated by RoPEInfo::createRoPEInfo; when that
// validation fails, CHECK_RESULT returns from this function before anything is
// allocated. On success *desc_ptr receives a heap-allocated Descriptor that
// needs no workspace (size 0) and owns an Opaque holding the handle's shared
// internal state, and INFINI_STATUS_SUCCESS is returned.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {

    auto cuda_handle = reinterpret_cast<device::cuda::Handle *>(handle_);

    auto rope_info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(rope_info);

    // The descriptor takes ownership of this allocation; its destructor deletes it.
    auto *opaque = new Opaque{cuda_handle->internal()};

    *desc_ptr = new Descriptor(
        rope_info.take(),
        0,
        opaque,
        cuda_handle->device,
        cuda_handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
// Launches the RoPE kernel on `stream`.
//
// Launch layout: a 2-D grid of (seqlen, nhead) blocks, one block per
// (token, head) row. Inside a block the threads stride over the
// `info.table_dim` element pairs of that row, so any block size >= 1 is correct.
//
// Parameters:
//   info        validated shapes/strides produced by RoPEInfo::createRoPEInfo
//   block_size  upper bound on threads per block (the device limit supplied by the caller)
//   y, x        output / input buffers
//   pos_ids     one position id per token
//   sin_table, cos_table  tables indexed by [pos_id][pair]
//
// Returns INFINI_STATUS_SUCCESS after enqueueing the kernel.
// NOTE(review): the launch status is not queried here; consider checking
// cudaGetLastError() with the project's CUDA error-check helper.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
                             int block_size,
                             Tdata *y,
                             const Tdata *x,
                             const Tindex *pos_ids,
                             const Tdata *sin_table,
                             const Tdata *cos_table,
                             cudaStream_t stream) {
    // Nothing to rotate. Returning early also avoids an invalid launch with a
    // zero-sized grid or block.
    if (info.seqlen == 0 || info.nhead == 0 || info.table_dim == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    // `unsigned int(expr)` is not a valid functional cast in standard C++, so use static_cast.
    const auto dimx = static_cast<unsigned int>(info.seqlen);
    const auto dimy = static_cast<unsigned int>(info.nhead);

    // One thread per pair, capped at the device limit. The original used std::max,
    // which always launched at least `block_size` threads and produced an invalid
    // launch configuration whenever table_dim exceeded the limit.
    const auto nthreads = static_cast<unsigned int>(
        std::min(info.table_dim, static_cast<size_t>(block_size)));

    ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);

    return INFINI_STATUS_SUCCESS;
}
// Casts the untyped buffers of Descriptor::calculate (y, x, pos_ids, sin_table,
// cos_table, stream) to the requested data/index types and forwards them to the
// calculateRoPE template above, together with _info and the thread-per-block
// limit read from the descriptor's opaque handle state.
// Only usable where those names are in scope.
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cudaStream_t)stream)
// Dispatches on the position-id dtype (_info.pos_type) for a fixed data type TDATA.
// Every branch returns from the enclosing function; a dtype that is not one of the
// eight listed integer types yields INFINI_STATUS_BAD_TENSOR_DTYPE.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Enqueues RoPE on the CUDA stream passed in `stream`.
//
// The data dtype stored in _info selects the element type; ROPE_TYPE then selects
// the position-id type and returns the result of the typed calculateRoPE call.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE when either dtype is not handled.
// `workspace` and `workspace_size` are accepted for interface uniformity and are
// not read by this implementation.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {

    // Every ROPE_TYPE expansion returns on all of its paths, so no case falls
    // through and nothing after the switch can execute. The unreachable trailing
    // return of the original was removed, matching the CPU implementation.
    switch (_info.data_type) {
    case INFINI_DTYPE_F16:
        ROPE_TYPE(half);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cuda
#ifndef __INFINIOP_ROPE_CUDA_H__
#define __INFINIOP_ROPE_CUDA_H__
#include "../rope.h"
DESCRIPTOR(cuda)
#endif // __INFINIOP_ROPE_CUDA_H__
#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
// RoPE kernel: one thread block per (token, head) row, one thread per element pair.
//
// Expected launch: grid = (seqlen, nhead), so blockIdx.x selects the token
// (and its entry in pos_ids) and blockIdx.y selects the head; a 1-D block of
// any size, whose threads stride over the `table_dim` pairs of the row.
// No shared memory is used.
//
// Each pair (2i, 2i + 1) is rotated by the table entry [pos_id][i]:
//   y[2i]     = x[2i] * cos - x[2i + 1] * sin
//   y[2i + 1] = x[2i] * sin + x[2i + 1] * cos
// y_ and x_ are deliberately not __restrict__; both inputs of a pair are read
// before the pair is written.
// NOTE(review): the half path accesses each pair as one half2, which presumably
// requires the row offsets to be even (4-byte aligned pairs) - confirm that the
// strides accepted by the host side guarantee this.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItem(
    Tdata *y_,
    const Tdata *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {

    // Element offsets of this block's row in the output and input tensors.
    auto dst_base = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
    auto src_base = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;

    // First element of the sin/cos table row selected by this token's position id.
    size_t row = size_t(pos_ids[blockIdx.x]);
    auto angle_base = row * table_dim;

    for (size_t pair = threadIdx.x; pair < table_dim; pair += blockDim.x) {
        Tangle sin_val = sin_table[angle_base + pair];
        Tangle cos_val = cos_table[angle_base + pair];

        if constexpr (std::is_same<Tdata, half>::value) {
            // Treat the two adjacent halves as a single half2 for the load and the store.
            auto &src = reinterpret_cast<const half2 &>(x_[src_base + 2 * pair]);
            auto &dst = reinterpret_cast<half2 &>(y_[dst_base + 2 * pair]);
            Tangle rot0 = src.x * cos_val - src.y * sin_val;
            Tangle rot1 = src.x * sin_val + src.y * cos_val;
            dst = half2(rot0, rot1);
        } else {
            Tangle even = x_[src_base + 2 * pair];
            Tangle odd = x_[src_base + 2 * pair + 1];
            y_[dst_base + 2 * pair] = Tdata(even * cos_val - odd * sin_val);
            y_[dst_base + 2 * pair + 1] = Tdata(even * sin_val + odd * cos_val);
        }
    }
}
#endif
...@@ -2,6 +2,13 @@ ...@@ -2,6 +2,13 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/rope.h" #include "infiniop/ops/rope.h"
#ifdef ENABLE_CPU_API
#include "cpu/rope_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopRoPEDescriptor_t *desc_ptr, infiniopRoPEDescriptor_t *desc_ptr,
...@@ -10,20 +17,24 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -10,20 +17,24 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopTensorDescriptor_t pos_ids, infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table, infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) { infiniopTensorDescriptor_t cos_table) {
switch (handle->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateRoPEDescriptor((CpuHandle_t)handle,
(RoPECpuDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaCreateRoPEDescriptor((CudaHandle_t)handle,
(RoPECudaDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rope::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rope::NAMESPACE::Descriptor **>(desc_ptr), \
y, \
x, \
pos_ids, \
sin_table, \
cos_table)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -54,21 +65,25 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -54,21 +65,25 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
} }
#endif #endif
} }
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: GET(INFINI_DEVICE_CPU, cpu);
return cpuGetRoPEWorkspaceSize((RoPECpuDescriptor_t)desc, size);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { GET(INFINI_DEVICE_NVIDIA, cuda);
return cudaGetRoPEWorkspaceSize((RoPECudaDescriptor_t)desc, size);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -91,6 +106,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -91,6 +106,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
} }
#endif #endif
} }
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -100,22 +118,22 @@ __C infiniStatus_t infiniopRoPE( ...@@ -100,22 +118,22 @@ __C infiniStatus_t infiniopRoPE(
size_t workspace_size, size_t workspace_size,
void *y, void *y,
const void *x, const void *x,
void const *pos_ids, const void *pos_ids,
void const *sin_table, const void *sin_table,
void const *cos_table, const void *cos_table,
void *stream) { void *stream) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRoPE((RoPECpuDescriptor_t)desc, workspace, workspace_size, t,
pos_ids, sin_table, cos_table, stream);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaRoPE((RoPECudaDescriptor_t)desc, workspace, workspace_size,
t, pos_ids, sin_table, cos_table, stream);
}
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, x, pos_ids, sin_table, cos_table, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -143,21 +161,26 @@ __C infiniStatus_t infiniopRoPE( ...@@ -143,21 +161,26 @@ __C infiniStatus_t infiniopRoPE(
} }
#endif #endif
} }
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t __C infiniStatus_t
infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: DELETE(INFINI_DEVICE_CPU, cpu);
return cpuDestroyRoPEDescriptor((RoPECpuDescriptor_t)desc);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { DELETE(INFINI_DEVICE_NVIDIA, cuda);
return cudaDestroyRoPEDescriptor((RoPECudaDescriptor_t)desc);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -180,5 +203,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -180,5 +203,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
} }
#endif #endif
} }
#undef DELETE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#ifndef __ROPE_H__
#define __ROPE_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
// Declares the per-backend RoPE descriptor class op::rope::NAMESPACE::Descriptor.
//
// The generated class derives from InfiniopDescriptor and stores:
//   _opaque          pointer to a backend-defined Opaque struct (only forward-declared here)
//   _info            the validated RoPEInfo
//   _workspace_size  the value reported by workspaceSize()
// Its constructor is private, so instances are obtained through the static create().
// The destructor, create() and calculate() are only declared here and must be
// defined by each backend's source file.
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rope::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RoPEInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
RoPEInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
const void *pos_ids, \
const void *sin_table, \
const void *cos_table, \
void *stream) const; \
}; \
}
// Validated, backend-independent description of one RoPE problem.
//
// Instances are produced only by createRoPEInfo(), which checks dtypes, ranks,
// shapes and strides of the five tensor descriptors and then records the values
// the compute code needs.
class RoPEInfo {
private:
    RoPEInfo() = default;

public:
    // Element dtype shared by y, x and both tables, and the dtype of the position ids.
    infiniDtype_t data_type, pos_type;
    // y/x have shape [seqlen, nhead, dhead]; the tables have shape [table_len, table_dim].
    size_t seqlen, nhead, dhead, table_len, table_dim;
    // Strides of the first two dimensions of y and x (the last dimension has stride 1).
    ptrdiff_t
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead;

    // Validates the descriptors and builds a RoPEInfo.
    //
    // Requirements (each failure is returned as the listed status):
    //   - no descriptor is null                                  -> INFINI_STATUS_NULL_POINTER
    //   - y, x, sin, cos share one dtype out of F16/F32/F64 and
    //     pos is one of the integer dtypes                       -> INFINI_STATUS_BAD_TENSOR_DTYPE
    //   - y and x are 3-D with equal shapes, pos is 1-D of length
    //     seqlen, sin and cos are 2-D with equal shapes, and
    //     dhead == 2 * table_dim                                 -> INFINI_STATUS_BAD_TENSOR_SHAPE
    //   - the last dimension of y and x has stride 1 and both
    //     tables are fully contiguous                            -> INFINI_STATUS_BAD_TENSOR_STRIDES
    //
    // NOTE(review): the stride of pos_desc is not validated although the compute
    // code indexes pos_ids as a dense array - confirm callers always pass a
    // contiguous position tensor.
    static utils::Result<RoPEInfo> createRoPEInfo(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t pos_desc,
        infiniopTensorDescriptor_t sin_desc,
        infiniopTensorDescriptor_t cos_desc) {
        // x_desc was missing from this guard although it is dereferenced right below.
        CHECK_OR_RETURN(
            y_desc != nullptr && x_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr,
            INFINI_STATUS_NULL_POINTER);

        const infiniDtype_t data_type = y_desc->dtype();
        const infiniDtype_t pos_type = pos_desc->dtype();
        CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);
        CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
        CHECK_DTYPE_ANY_INT(pos_type);

        CHECK_OR_RETURN(y_desc->ndim() == 3
                            && x_desc->ndim() == 3
                            && pos_desc->ndim() == 1
                            && sin_desc->ndim() == 2
                            && cos_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        const auto seqlen = y_desc->dim(0),
                   nhead = y_desc->dim(1),
                   dhead = y_desc->dim(2),
                   table_len = sin_desc->dim(0),
                   table_dim = sin_desc->dim(1);

        CHECK_OR_RETURN(seqlen == x_desc->dim(0)
                            && seqlen == pos_desc->dim(0)
                            && nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
                            && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        // Each table entry rotates one pair of adjacent elements.
        CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Last dimension of x and y must be contiguous
        CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // sin table and cos table must be totally contiguous
        CHECK_OR_RETURN(sin_desc->stride(1) == 1
                            && cos_desc->stride(1) == 1
                            && sin_desc->stride(0) == ptrdiff_t(table_dim)
                            && cos_desc->stride(0) == ptrdiff_t(table_dim),
                        INFINI_STATUS_BAD_TENSOR_STRIDES);

        return utils::Result<RoPEInfo>(RoPEInfo{
            data_type,
            pos_type,
            seqlen,
            nhead,
            dhead,
            table_len,
            table_dim,
            y_desc->stride(0),
            y_desc->stride(1),
            x_desc->stride(0),
            x_desc->stride(1),
        });
    }
};
#endif
...@@ -3,6 +3,16 @@ ...@@ -3,6 +3,16 @@
#include <iostream> #include <iostream>
#include <tuple> #include <tuple>
#define CHECK_OR_RETURN(CONDITION, ERROR) \
do { \
if (!(CONDITION)) { \
std::cerr << "Check Failed: `(" << #CONDITION << ")` is False" \
<< " from " << __func__ \
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
return ERROR; \
} \
} while (0)
#define CHECK_API_OR(API, EXPECT, ACTION) \ #define CHECK_API_OR(API, EXPECT, ACTION) \
do { \ do { \
auto api_result_ = (API); \ auto api_result_ = (API); \
...@@ -31,6 +41,11 @@ ...@@ -31,6 +41,11 @@
return INFINI_STATUS_BAD_TENSOR_DTYPE); \ return INFINI_STATUS_BAD_TENSOR_DTYPE); \
} while (0) } while (0)
#define CHECK_DTYPE_ANY_INT(DT) \
CHECK_DTYPE(DT, \
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64, \
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
#define CHECK_SAME_VEC(ERR, FIRST, ...) \ #define CHECK_SAME_VEC(ERR, FIRST, ...) \
do { \ do { \
for (const auto &shape___ : {__VA_ARGS__}) { \ for (const auto &shape___ : {__VA_ARGS__}) { \
......
...@@ -10,7 +10,7 @@ def check_error(status): ...@@ -10,7 +10,7 @@ def check_error(status):
raise Exception("Error code " + str(status)) raise Exception("Error code " + str(status))
def to_tensor(tensor, lib): def to_tensor(tensor, lib, force_unsigned=False):
""" """
Convert a PyTorch tensor to a library Tensor(descriptor, data). Convert a PyTorch tensor to a library Tensor(descriptor, data).
""" """
...@@ -37,6 +37,16 @@ def to_tensor(tensor, lib): ...@@ -37,6 +37,16 @@ def to_tensor(tensor, lib):
InfiniDtype.U64 if tensor.dtype == torch.uint64 else InfiniDtype.U64 if tensor.dtype == torch.uint64 else
None None
) )
if force_unsigned:
dt = (
InfiniDtype.U8 if dt == InfiniDtype.I8 else
InfiniDtype.U16 if dt == InfiniDtype.I16 else
InfiniDtype.U32 if dt == InfiniDtype.I32 else
InfiniDtype.U64 if dt == InfiniDtype.I64 else
dt
)
# fmt: on # fmt: on
assert dt is not None assert dt is not None
# Create TensorDecriptor # Create TensorDecriptor
......
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
from libinfiniop import ( from libinfiniop import (
InfiniDtype,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
open_lib, open_lib,
...@@ -18,30 +17,49 @@ from libinfiniop import ( ...@@ -18,30 +17,49 @@ from libinfiniop import (
profile_operation, profile_operation,
synchronize_device, synchronize_device,
) )
from enum import Enum, auto
# ============================================================================== # ==============================================================================
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES_ = [
# (t_shape, t_strides) # (shape, x_strides, y_strides)
((1, 32, 128), None), ((1, 32, 128), None, None),
((1, 32, 64), None), ((10, 32, 64), None, None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心 # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None), ((4, 1, 32), (64, 64, 1), None),
((1, 32, 128), None), ((11, 33, 128), None, (8000, 200, 1)),
((3, 32, 128), (8000, 200, 1)), ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
] ]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [torch.float16] _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2}, torch.float16: {"atol": 1e-3, "rtol": 1e-2},
torch.float32: {"atol": 1e-4, "rtol": 1e-3},
} }
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
DEBUG = False DEBUG = False
PROFILE = False PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
...@@ -55,23 +73,21 @@ class RoPEDescriptor(Structure): ...@@ -55,23 +73,21 @@ class RoPEDescriptor(Structure):
infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor) infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): def rotary_embedding(t, sin, cos, torch_device):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[0], x.shape[-1])
shape = [d if i == 0 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotary_embedding(t, pos, theta, torch_device):
dh = t.shape[2] dh = t.shape[2]
dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even." assert dh % 2 == 0, "Embedding dimension must be even."
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2] t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2] t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device) cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2]
freqs = torch.outer(pos, freqs) # [seq_len, dh // 2] sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2]
cos = torch.cos(freqs).unsqueeze(1) # [seq_len, 1, dh // 2] if torch_device == "cpu":
sin = torch.sin(freqs).unsqueeze(1) # [seq_len, 1, dh // 2] (t_even, t_odd, cos, sin) = (
t_even.float(),
t_odd.float(),
cos.float(),
sin.float(),
)
t_out_even = t_even * cos - t_odd * sin t_out_even = t_even * cos - t_odd * sin
t_out_odd = t_even * sin + t_odd * cos t_out_odd = t_even * sin + t_odd * cos
...@@ -80,51 +96,56 @@ def rotary_embedding(t, pos, theta, torch_device): ...@@ -80,51 +96,56 @@ def rotary_embedding(t, pos, theta, torch_device):
t_out[..., 0::2] = t_out_even t_out[..., 0::2] = t_out_even
t_out[..., 1::2] = t_out_odd t_out[..., 1::2] = t_out_odd
return t_out return t_out.to(dt).to(torch_device)
def sin_cos_table(max_seq_len, dim, torch_device, theta): def sin_cos_table(pos, dim, torch_device, theta, dtype):
pos = torch.arange( assert dim % 2 == 0, "Embedding dimension must be even."
0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
)
freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).to( freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).to(
torch_device torch_device
) )
# (a0, a1, a2) -> (a0, a0, a1, a1, a2, a2)
freqs = torch.repeat_interleave(freqs, repeats=2)
angles = torch.outer(pos, freqs) angles = torch.outer(pos, freqs)
return torch.sin(angles), torch.cos(angles) return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): def test(
lib,
handle,
torch_device,
shape,
x_strides=None,
y_strides=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32,
):
if inplace == Inplace.INPLACE_X:
y_strides = x_strides
print( print(
f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}" f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{dtype} inplace:{inplace}"
) )
t = torch.rand(shape, dtype=dtype) x = torch.rand(shape, dtype=dtype).to(torch_device)
x = rearrange_if_needed(x, x_strides)
if inplace == Inplace.INPLACE_X:
y = x
else:
y = torch.rand(shape, dtype=dtype).to(torch_device)
y = rearrange_if_needed(y, y_strides)
theta = 1e5
pos = torch.arange(0, x.shape[0], dtype=torch.int32).to(torch_device)
sin_table, cos_table = sin_cos_table(pos, x.shape[2], x.device, theta, dtype)
t = rearrange_if_needed(t, strides) ans = rotary_embedding(x, sin_table, cos_table, torch_device)
posTmp = torch.arange(0, t.shape[0]).to(torch_device)
pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]):
pos[2 * i] = posTmp[i]
pos[2 * i + 1] = 0
pos = pos.to(torch_device)
theta = 1e4
ans = rotary_embedding(t, posTmp, theta, torch_device)
descriptor = infiniopRoPEDescriptor_t() descriptor = infiniopRoPEDescriptor_t()
# 2x table length for test x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor = [
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) to_tensor(tensor, lib, force_unsigned=True)
for tensor in [x, pos, sin_table, cos_table]
t_tensor, sin_table_tensor, cos_table_tensor = [
to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
] ]
if inplace == Inplace.INPLACE_X:
pos_tensor = to_tensor(pos[: t.shape[0]], lib) y_tensor = x_tensor
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64 else:
y_tensor = to_tensor(y, lib)
if torch_device == "npu": if torch_device == "npu":
synchronize_device(torch_device) synchronize_device(torch_device)
...@@ -133,7 +154,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -133,7 +154,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
t_tensor.descriptor, y_tensor.descriptor,
x_tensor.descriptor,
pos_tensor.descriptor, pos_tensor.descriptor,
sin_table_tensor.descriptor, sin_table_tensor.descriptor,
cos_table_tensor.descriptor, cos_table_tensor.descriptor,
...@@ -141,14 +163,14 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -141,14 +163,14 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]: for tensor in [y_tensor, x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
tensor.descriptor.contents.invalidate() tensor.destroyDesc(lib)
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size)) lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
) )
workspace = create_workspace(workspace_size.value, t.device) workspace = create_workspace(workspace_size.value, x.device)
def lib_rope(): def lib_rope():
check_error( check_error(
...@@ -156,7 +178,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -156,7 +178,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
descriptor, descriptor,
workspace.data_ptr() if workspace is not None else None, workspace.data_ptr() if workspace is not None else None,
workspace_size.value, workspace_size.value,
t_tensor.data, y_tensor.data,
x_tensor.data,
pos_tensor.data, pos_tensor.data,
sin_table_tensor.data, sin_table_tensor.data,
cos_table_tensor.data, cos_table_tensor.data,
...@@ -168,13 +191,13 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -168,13 +191,13 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(t, ans, atol=atol, rtol=rtol) debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(t, ans, atol=atol, rtol=rtol) assert torch.allclose(y, ans, atol=atol, rtol=rtol)
if PROFILE: if PROFILE:
profile_operation( profile_operation(
"PyTorch", "PyTorch",
lambda: rotary_embedding(t, posTmp, theta, torch_device), lambda: rotary_embedding(x, pos, theta, torch_device),
torch_device, torch_device,
NUM_PRERUN, NUM_PRERUN,
NUM_ITERATIONS, NUM_ITERATIONS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment