Commit bf4f41b6 authored by PanZezhong
Browse files

issue/48 rope cuda

parent 07279a25
#include "../../../devices/cuda/cuda_common.cuh"
#include "rope_cuda.cuh"
#include "rope_cuda_kernel.cuh"
namespace op::rope::cuda {
// Backend-private state carried by the CUDA RoPE descriptor.
struct Descriptor::Opaque {
// Shared CUDA handle internals; Descriptor::calculate reads
// maxThreadsPerBlock() through this pointer to bound the launch block size.
std::shared_ptr<device::cuda::Handle::Internal> internal;
};
// Releases the backend-private state.
// NOTE(review): _opaque is presumably the Opaque allocated with `new` in
// Descriptor::create — confirm against the Descriptor constructor.
Descriptor::~Descriptor() {
delete _opaque;
}
/**
 * Builds a CUDA RoPE descriptor from the tensor descriptors of one call.
 *
 * @param handle_   Generic handle; reinterpreted as device::cuda::Handle.
 * @param desc_ptr  Receives the newly allocated descriptor on success.
 * @param y_desc    Output tensor descriptor.
 * @param x_desc    Input tensor descriptor.
 * @param pos_desc  Position-id tensor descriptor.
 * @param sin_desc  Sine table descriptor.
 * @param cos_desc  Cosine table descriptor.
 * @return INFINI_STATUS_SUCCESS on success; otherwise whatever CHECK_RESULT
 *         propagates when RoPEInfo::createRoPEInfo rejects the descriptors.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto *cuda_handle = reinterpret_cast<device::cuda::Handle *>(handle_);

    // Validate the descriptors and collect sizes/strides before allocating anything.
    auto rope_info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(rope_info);

    // The descriptor keeps a share of the handle internals for launch-time queries.
    auto *opaque = new Opaque{cuda_handle->internal()};

    // Second argument is a literal 0 — presumably the workspace size; confirm
    // against the Descriptor constructor.
    *desc_ptr = new Descriptor(
        rope_info.take(),
        0,
        opaque,
        cuda_handle->device,
        cuda_handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
/**
 * Enqueues the RoPE kernel for one (data type, position-index type) pair.
 *
 * Launch layout: a 2-D grid of info.seqlen x info.nhead blocks, one block per
 * (sequence position, head). Inside a block the threads stride over the
 * info.table_dim sin/cos entries, so any block size >= 1 is correct.
 *
 * @param info        Problem description (sizes and element strides).
 * @param block_size  Upper bound for threads per block; the caller in this
 *                    file passes the device's maxThreadsPerBlock().
 * @param y           Output buffer.
 * @param x           Input buffer.
 * @param pos_ids     Position ids, indexed by sequence position in the kernel.
 * @param sin_table   Sine table, table_dim entries per position id.
 * @param cos_table   Cosine table, table_dim entries per position id.
 * @param stream      CUDA stream the kernel is enqueued on.
 * @return INFINI_STATUS_SUCCESS once the kernel has been enqueued.
 *
 * NOTE(review): the launch result is not inspected here; consider checking
 * cudaGetLastError() and mapping it to the project's error status.
 */
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
                             int block_size,
                             Tdata *y,
                             const Tdata *x,
                             const Tindex *pos_ids,
                             const Tdata *sin_table,
                             const Tdata *cos_table,
                             cudaStream_t stream) {
    // Nothing to compute; a zero-sized grid or block would also be an invalid
    // launch configuration.
    if (info.seqlen == 0 || info.nhead == 0 || info.table_dim == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    // `unsigned int(expr)` is not a valid functional cast in standard C++
    // (the type name has two words), so spell the conversion explicitly.
    const auto dimx = static_cast<unsigned int>(info.seqlen);
    const auto dimy = static_cast<unsigned int>(info.nhead);

    // Never launch more threads than there are table entries, and never more
    // than the device allows. The previous std::max did the opposite: it always
    // used at least the device maximum and exceeded it for large table_dim.
    const int nthreads = std::min(static_cast<int>(info.table_dim), block_size);

    ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead,
        info.x_stride_seqlen, info.x_stride_nhead);

    return INFINI_STATUS_SUCCESS;
}
// Expands to a calculateRoPE call for the given element type (TDATA) and
// position-index type (TINDEX). It must be used inside Descriptor::calculate:
// it captures _info, _opaque and the parameters y, x, pos_ids, sin_table,
// cos_table and stream by name, casting the untyped pointers to the chosen types.
// The device's maxThreadsPerBlock() is passed as the block-size bound.
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cudaStream_t)stream)
// Second-level dispatch for a fixed element type TDATA: selects the unsigned
// integer type matching _info.pos_type and returns the result of
// CALCULATE_ROPE. Every branch returns, so control never continues past the
// expansion; position types other than U8/U16/U32/U64 yield
// INFINI_STATUS_BAD_TENSOR_DTYPE.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Runs RoPE on the CUDA backend by dispatching on the stored data type and
// position-index type to the matching calculateRoPE instantiation.
//
// workspace / workspace_size: not referenced by this implementation.
// y, x:                  output and input buffers, cast to the element type.
// pos_ids:               position ids, cast to the unsigned type of _info.pos_type.
// sin_table, cos_table:  tables, cast to the same element type as x and y.
// stream:                cast to cudaStream_t and used for the kernel launch.
//
// Returns the status of calculateRoPE, or INFINI_STATUS_BAD_TENSOR_DTYPE when
// the data type is not F16/F32/F64 or the position type is unsupported.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
// ROPE_TYPE returns from every branch, so the cases do not fall through.
ROPE_TYPE(half);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Not reached: every path through the switch above returns.
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cuda
#ifndef __INFINIOP_ROPE_CUDA_H__
#define __INFINIOP_ROPE_CUDA_H__
#include "../rope.h"
// Instantiates the DESCRIPTOR macro for the `cuda` backend.
// NOTE(review): presumably this comes from ../rope.h and declares
// op::rope::cuda::Descriptor, whose members are defined in the .cu file — confirm.
DESCRIPTOR(cuda)
#endif // __INFINIOP_ROPE_CUDA_H__
#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
/**
 * Rotary position embedding kernel.
 *
 * Expected launch: a 2-D grid where blockIdx.x selects the sequence position
 * and blockIdx.y selects the head; a 1-D block of any size. Threads of a block
 * stride over the table_dim rotation pairs, so thread t handles pairs
 * t, t + blockDim.x, t + 2 * blockDim.x, ...
 *
 * For pair i the kernel reads elements 2*i and 2*i+1 of the input row, rotates
 * them by the angle whose sine/cosine are stored at
 * table[pos_ids[blockIdx.x] * table_dim + i], and writes the result to the same
 * two positions of the output row. Both inputs of a pair are read before either
 * output is written.
 *
 * When Tdata is half, each pair is accessed as one half2 via reinterpret_cast.
 * NOTE(review): this assumes the pair is suitably aligned for half2, i.e. the
 * base pointers and strides keep element offsets even — confirm with callers.
 *
 * Strides are given in elements of Tdata.
 */
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItem(
    Tdata *y_,
    const Tdata *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    // Row bases for this block's (sequence position, head).
    const auto dst_base = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
    const auto src_base = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;

    // First sin/cos entry belonging to this sequence position's position id.
    const size_t position = size_t(pos_ids[blockIdx.x]);
    const auto table_base = position * table_dim;

    for (size_t pair = threadIdx.x; pair < table_dim; pair += blockDim.x) {
        const Tangle s = sin_table[table_base + pair];
        const Tangle c = cos_table[table_base + pair];

        const auto dst_idx = dst_base + 2 * pair;
        const auto src_idx = src_base + 2 * pair;

        if constexpr (std::is_same<Tdata, half>::value) {
            // Load and store the pair as a single half2.
            const auto &src = reinterpret_cast<const half2 &>(x_[src_idx]);
            const Tangle rot0 = src.x * c - src.y * s;
            const Tangle rot1 = src.x * s + src.y * c;
            reinterpret_cast<half2 &>(y_[dst_idx]) = half2(rot0, rot1);
        } else {
            const Tangle in0 = x_[src_idx];
            const Tangle in1 = x_[src_idx + 1];
            y_[dst_idx] = Tdata(in0 * c - in1 * s);
            y_[dst_idx + 1] = Tdata(in0 * s + in1 * c);
        }
    }
}
#endif
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/rope_cpu.h" #include "cpu/rope_cpu.h"
#endif #endif
#ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -30,13 +33,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -30,13 +33,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu); CREATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { CREATE(INFINI_DEVICE_NVIDIA, cuda);
return cudaCreateRoPEDescriptor((CudaHandle_t)handle,
(RoPECudaDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -84,11 +82,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -84,11 +82,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu); GET(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { GET(INFINI_DEVICE_NVIDIA, cuda);
return cudaGetRoPEWorkspaceSize((RoPECudaDescriptor_t)desc, size);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -137,12 +132,8 @@ __C infiniStatus_t infiniopRoPE( ...@@ -137,12 +132,8 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu); CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
return cudaRoPE((RoPECudaDescriptor_t)desc, workspace, workspace_size,
t, pos_ids, sin_table, cos_table, stream);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -188,11 +179,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -188,11 +179,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu); DELETE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: { DELETE(INFINI_DEVICE_NVIDIA, cuda);
return cudaDestroyRoPEDescriptor((RoPECudaDescriptor_t)desc);
}
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
......
...@@ -26,7 +26,7 @@ from enum import Enum, auto ...@@ -26,7 +26,7 @@ from enum import Enum, auto
_TEST_CASES_ = [ _TEST_CASES_ = [
# (shape, x_strides, y_strides) # (shape, x_strides, y_strides)
((1, 32, 128), None, None), ((1, 32, 128), None, None),
((1, 32, 64), None, None), ((10, 32, 64), None, None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心 # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), (64, 64, 1), None), ((4, 1, 32), (64, 64, 1), None),
...@@ -39,7 +39,7 @@ _TENSOR_DTYPES = [torch.float16, torch.float32] ...@@ -39,7 +39,7 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2}, torch.float16: {"atol": 1e-3, "rtol": 1e-2},
torch.float32: {"atol": 1e-4, "rtol": 1e-3}, torch.float32: {"atol": 1e-4, "rtol": 1e-3},
} }
...@@ -77,10 +77,17 @@ def rotary_embedding(t, sin, cos, torch_device): ...@@ -77,10 +77,17 @@ def rotary_embedding(t, sin, cos, torch_device):
dh = t.shape[2] dh = t.shape[2]
dt = t.dtype dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even." assert dh % 2 == 0, "Embedding dimension must be even."
t_even = t[..., 0::2].float() # [seq_len, n_head, dh // 2] t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2].float() # [seq_len, n_head, dh // 2] t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
cos = cos.unsqueeze(1).float() # [seq_len, 1, dh // 2] cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2]
sin = sin.unsqueeze(1).float() # [seq_len, 1, dh // 2] sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2]
if torch_device == "cpu":
(t_even, t_odd, cos, sin) = (
t_even.float(),
t_odd.float(),
cos.float(),
sin.float(),
)
t_out_even = t_even * cos - t_odd * sin t_out_even = t_even * cos - t_odd * sin
t_out_odd = t_even * sin + t_odd * cos t_out_odd = t_even * sin + t_odd * cos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment