Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
650bc975
Commit
650bc975
authored
Mar 13, 2025
by
PanZezhong
Browse files
issue/6 添加cuda通用reduce模块,实现rmsnorm cuda算子
parent
240b1236
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
194 additions
and
19 deletions
+194
-19
include/infinicore.h
include/infinicore.h
+1
-0
src/infiniop/devices/cuda/cuda_common.cuh
src/infiniop/devices/cuda/cuda_common.cuh
+12
-0
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cu
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cu
+95
-0
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cuh
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cuh
+8
-0
src/infiniop/ops/rms_norm/cuda/rms_norm_kernel.cuh
src/infiniop/ops/rms_norm/cuda/rms_norm_kernel.cuh
+37
-0
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+11
-19
src/infiniop/ops/rms_norm/rms_norm.h
src/infiniop/ops/rms_norm/rms_norm.h
+4
-0
src/infiniop/reduce/cuda/reduce.cuh
src/infiniop/reduce/cuda/reduce.cuh
+26
-0
No files found.
include/infinicore.h
View file @
650bc975
...
@@ -28,6 +28,7 @@ typedef enum {
...
@@ -28,6 +28,7 @@ typedef enum {
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
=
5
,
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
=
5
,
INFINI_STATUS_DEVICE_NOT_FOUND
=
6
,
INFINI_STATUS_DEVICE_NOT_FOUND
=
6
,
INFINI_STATUS_DEVICE_NOT_INITIALIZED
=
7
,
INFINI_STATUS_DEVICE_NOT_INITIALIZED
=
7
,
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
=
8
,
// Op Errors
// Op Errors
INFINI_STATUS_BAD_TENSOR_DTYPE
=
10
,
INFINI_STATUS_BAD_TENSOR_DTYPE
=
10
,
INFINI_STATUS_BAD_TENSOR_SHAPE
=
11
,
INFINI_STATUS_BAD_TENSOR_SHAPE
=
11
,
...
...
src/infiniop/devices/cuda/cuda_common.cuh
View file @
650bc975
#ifndef __INFINIOP_CUDA_COMMON_CUH__
#ifndef __INFINIOP_CUDA_COMMON_CUH__
#define __INFINIOP_CUDA_COMMON_CUH__
#define __INFINIOP_CUDA_COMMON_CUH__
#include "../../reduce/cuda/reduce.cuh"
#include "cuda_handle.cuh"
#include "cuda_handle.cuh"
#include "infinicore.h"
#include "infinicore.h"
#ifdef ENABLE_SUGON_CUDA_API
#define INFINIOP_CUDA_KERNEL __launch_bounds__(512) __global__ void
#else
#define INFINIOP_CUDA_KERNEL __global__ void
#endif
// Posible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
namespace
device
::
cuda
{
namespace
device
::
cuda
{
cudnnDataType_t
getCudnnDtype
(
infiniDtype_t
dt
);
cudnnDataType_t
getCudnnDtype
(
infiniDtype_t
dt
);
...
...
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cu
0 → 100644
View file @
650bc975
#include "../../../devices/cuda/cuda_common.cuh"
#include "rms_norm_cuda.cuh"
#include "rms_norm_kernel.cuh"
#include <memory>
#include <stdint.h>
namespace
op
::
rms_norm
::
cuda
{
// Backend-private state hidden behind the public descriptor interface.
// Owns a shared reference to the CUDA device internals (used later to
// query launch limits such as the maximum threads per block).
struct Descriptor::Opaque {
    std::shared_ptr<device::cuda::Handle::Internal> internal;
};
// Destructor: releases the backend-private state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Build a CUDA RMSNorm descriptor.
//
// Validates the tensor descriptors via createRMSNormInfo, rejects layouts
// whose last dimension is not contiguous (the kernel walks rows linearly),
// and allocates the descriptor together with its backend-private state.
//
// Parameters:
//   handle   - infiniop handle; must wrap a device::cuda::Handle.
//   desc_ptr - out: receives the newly allocated descriptor.
//   y_desc   - output tensor descriptor.
//   x_desc   - input tensor descriptor.
//   w_desc   - weight tensor descriptor.
//   epsilon  - numerical-stability term added before the rsqrt.
// Returns INFINI_STATUS_SUCCESS, or an error status from validation.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t w_desc,
    float epsilon) {

    RMSNormInfo info;
    CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));

    // Only a contiguous last dimension is supported for both x and y.
    const bool last_dim_contiguous = info.x_strides[1] == 1 && info.y_strides[1] == 1;
    if (!last_dim_contiguous) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto *cuda_handle = reinterpret_cast<device::cuda::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{cuda_handle->internal()},
        info,
        0, // no workspace is required by this implementation
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatch rmsnormBlock for the supported (data, weight) dtype pairs.
//
// Supported combinations (compute is always float):
//   data=F16, weight=F16
//   data=F16, weight=F32
//   data=F32, weight=F32
// Any other pairing yields INFINI_STATUS_BAD_TENSOR_DTYPE.
// One block is launched per row (batch_size blocks of BLOCK_SIZE threads).
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size,
    size_t dim,
    void *y,
    infiniDtype_t atype,
    ptrdiff_t stride_y,
    const void *x,
    ptrdiff_t stride_x,
    const void *w,
    infiniDtype_t wtype,
    float epsilon,
    cudaStream_t cuda_stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
    rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
        reinterpret_cast<Tdata *>(y), \
        stride_y, \
        reinterpret_cast<const Tdata *>(x), \
        stride_x, \
        reinterpret_cast<const Tweight *>(w), \
        dim, \
        epsilon)

    if (atype == INFINI_DTYPE_F16) {
        if (wtype == INFINI_DTYPE_F16) {
            LAUNCH_KERNEL(half, half, float);
        } else if (wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL(half, float, float);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}
// Run RMSNorm on the device.
//
// Picks a kernel block size matching the device's maximum threads per
// block (1024 or 512; anything else is reported as an unsupported
// architecture) and launches one block per row.
//
// Parameters:
//   workspace / workspace_size - scratch buffer; must be at least
//                                _workspace_size bytes (currently 0).
//   y, x, w                    - device pointers for output, input, weight.
//   stream                     - opaque CUDA stream handle.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *w,
    void *stream) {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    const auto stride_x = _info.x_strides[0];
    const auto stride_y = _info.y_strides[0];
    const auto dim = _info.dim();
    const auto batch_size = static_cast<uint32_t>(_info.shape[0]);
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

    // Select the launch configuration from the device's block-size limit.
    switch (_opaque->internal->maxThreadsPerBlock()) {
    case CUDA_BLOCK_SIZE_1024:
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            batch_size, dim, y, _info.atype, stride_y,
            x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream));
        break;
    case CUDA_BLOCK_SIZE_512:
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            batch_size, dim, y, _info.atype, stride_y,
            x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream));
        break;
    default:
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::rms_norm::cuda
src/infiniop/ops/rms_norm/cuda/rms_norm_cuda.cuh
0 → 100644
View file @
650bc975
#ifndef __RMS_NORM_CUDA_H__
#define __RMS_NORM_CUDA_H__

#include "../rms_norm.h"

// Expand the shared descriptor boilerplate for the CUDA backend
// (declares op::rms_norm::cuda::Descriptor).
DESCRIPTOR(cuda)

#endif // __RMS_NORM_CUDA_H__
src/infiniop/ops/rms_norm/cuda/rms_norm_kernel.cuh
0 → 100644
View file @
650bc975
#ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__

#include "../../../devices/cuda/cuda_common.cuh"
#include <cub/block/block_reduce.cuh>

// RMSNorm kernel: y[i] = x[i] * w[i] / sqrt(mean(x^2) + epsilon).
//
// Launch layout: one thread block per row (gridDim.x = batch), BLOCK_SIZE
// threads per block. Each block normalizes one contiguous row of length
// `dim`; each thread handles every BLOCK_SIZE-th element of that row.
//
// Template parameters:
//   BLOCK_SIZE - threads per block (must match the launch configuration).
//   Tdata      - element type of x and y (e.g. half, float).
//   Tweight    - element type of w.
//   Tcompute   - accumulation/compute type (typically float).
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
INFINIOP_CUDA_KERNEL rmsnormBlock(
    Tdata *__restrict__ y,
    ptrdiff_t stride_y,
    const Tdata *__restrict__ x,
    ptrdiff_t stride_x,
    const Tweight *__restrict__ w,
    size_t dim,
    float epsilon) {
    // Each block takes care of a row of continuous data of length dim
    // Each thread deals with every BLOCK_SIZE-th element in the row
    auto y_ptr = y + blockIdx.x * stride_y;
    auto x_ptr = x + blockIdx.x * stride_x;
    auto w_ptr = w;

    // Block-reduce sum of x^2 (all threads must participate).
    Tcompute ss = op::common_cuda::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(x_ptr, dim);

    // Thread 0 computes RMS = 1/sqrt(ss/dim + epsilon) and broadcasts it
    // via shared memory.
    __shared__ Tcompute rms;
    if (threadIdx.x == 0) {
        // BUGFIX: keep the scale in Tcompute. The previous Tdata(...) cast
        // rounded the value to half precision before storing it into the
        // Tcompute shared variable, needlessly degrading accuracy for
        // Tdata=half.
        rms = Tcompute(rsqrtf(ss / Tcompute(dim) + epsilon));
    }
    __syncthreads();

    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        y_ptr[i] = Tdata(Tcompute(x_ptr[i]) * Tcompute(w_ptr[i]) * rms);
    }
}

#endif
src/infiniop/ops/rms_norm/operator.cc
View file @
650bc975
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
#include "cpu/rms_norm_cpu.h"
#include "cpu/rms_norm_cpu.h"
#endif
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/rms_norm_cuda.cuh"
#endif
__C
infiniStatus_t
infiniopCreateRMSNormDescriptor
(
__C
infiniStatus_t
infiniopCreateRMSNormDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -28,10 +31,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
...
@@ -28,10 +31,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaCreateRMSNormDescriptor
((
CudaHandle_t
)
handle
,
(
RMSNormCudaDescriptor_t
*
)
desc_ptr
,
y_desc
,
x_desc
,
w_desc
,
epsilon
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -76,11 +77,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
...
@@ -76,11 +77,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
GET
(
INFINI_DEVICE_CPU
,
cpu
)
GET
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaGetRMSNormWorkspaceSize
((
RMSNormCudaDescriptor_t
)
desc
,
size
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -122,11 +120,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
...
@@ -122,11 +120,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaRMSNorm
((
RMSNormCudaDescriptor_t
)
desc
,
workspace
,
workspace_size
,
y
,
x
,
w
,
stream
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -172,11 +167,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
...
@@ -172,11 +167,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
DESTROY
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaDestroyRMSNormDescriptor
((
RMSNormCudaDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
...
src/infiniop/ops/rms_norm/rms_norm.h
View file @
650bc975
...
@@ -51,6 +51,10 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
...
@@ -51,6 +51,10 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
if
(
w_desc
->
stride
(
0
)
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
info
->
shape
=
std
::
move
(
y_desc
->
shape
());
info
->
shape
=
std
::
move
(
y_desc
->
shape
());
info
->
y_strides
=
std
::
move
(
y_desc
->
strides
());
info
->
y_strides
=
std
::
move
(
y_desc
->
strides
());
info
->
x_strides
=
std
::
move
(
x_desc
->
strides
());
info
->
x_strides
=
std
::
move
(
x_desc
->
strides
());
...
...
src/infiniop/reduce/cuda/reduce.cuh
0 → 100644
View file @
650bc975
#ifndef __INFINIOP_REDUCE_CUDA_H__
#define __INFINIOP_REDUCE_CUDA_H__

#include <cub/block/block_reduce.cuh>

namespace op::common_cuda::reduce_op {

// Block-level sum of squares of `count` elements starting at `data_ptr`.
//
// Every thread of the block must call this function (it performs a CUB
// block-wide collective using shared memory). The full reduction result
// is returned on thread 0; other threads receive an unspecified partial
// value, per cub::BlockReduce::Sum semantics.
//
// Template parameters:
//   BLOCK_SIZE - threads per block (must match the launch configuration).
//   Tdata      - element type of the input.
//   Tcompute   - accumulation type (typically float).
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t count) {
    Tcompute ss = 0;
    // Each thread accumulates a partial sum with a block-stride loop.
    for (size_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
        // BUGFIX: promote to Tcompute BEFORE squaring. The previous
        // Tcompute(data_ptr[i] * data_ptr[i]) squared in Tdata, which for
        // half overflows to inf once |x| exceeds ~255.9 (half max 65504)
        // and loses precision for smaller magnitudes.
        const Tcompute v = Tcompute(data_ptr[i]);
        ss += v * v;
    }
    // Combine the per-thread partials with a CUB block-level reduction.
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Sum(ss);
}

} // namespace op::common_cuda::reduce_op

#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment