Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a8955429
Commit
a8955429
authored
Apr 09, 2025
by
PanZezhong
Browse files
issue/4 添加cuda causal softmax算子
parent
77409cea
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
166 additions
and
21 deletions
+166
-21
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
+59
-0
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
+8
-0
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
...nfiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
+49
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+11
-20
src/infiniop/reduce/cuda/reduce.cuh
src/infiniop/reduce/cuda/reduce.cuh
+37
-0
xmake.lua
xmake.lua
+2
-1
No files found.
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
0 → 100644
View file @
a8955429
#include "causal_softmax_cuda.cuh"
#include "../../../devices/cuda/cuda_common.cuh"
#include "causal_softmax_kernel.cuh"
namespace
op
::
causal_softmax
::
cuda
{
// Backend-private state for the CUDA causal-softmax Descriptor.
// Holds a shared handle to the per-device CUDA internals (used by
// calculate() to query maxThreadsPerBlock()).
struct Descriptor::Opaque {
    std::shared_ptr<device::cuda::Handle::Internal> internal;
};
// Releases the backend-private state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Builds a CUDA causal-softmax descriptor from the output tensor descriptor.
// Validates/extracts shape and stride info, then wraps the device handle's
// internals into the descriptor's opaque state.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc) {

    // Validate the tensor layout and capture batch/seq sizes and strides.
    CausalSoftmaxInfo info;
    CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc));

    auto cuda_handle = reinterpret_cast<device::cuda::Handle *>(handle);
    auto opaque = new Opaque{cuda_handle->internal()};

    // 0 = workspace size; this operator works in place and needs none.
    *desc_ptr = new Descriptor(opaque, info, 0,
                               handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the causalSoftmax kernel for the supported dtypes.
// Grid layout: grid.x = row within the sequence, grid.y = batch index;
// BLOCK_SIZE threads stride across the columns of a row.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    void *data,
    infiniDtype_t dtype,
    size_t batch_size,
    size_t seq_len,
    size_t total_seq_len,
    ptrdiff_t stride_b,
    ptrdiff_t stride_i,
    cudaStream_t stream) {

    dim3 grid(uint32_t(seq_len), uint32_t(batch_size), 1);

    switch (dtype) {
    case INFINI_DTYPE_F16:
        // half storage, float accumulation for the row sum.
        causalSoftmax<BLOCK_SIZE, half, float><<<grid, BLOCK_SIZE, 0, stream>>>(
            (half *)data, batch_size, seq_len, total_seq_len, stride_b, stride_i);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F32:
        causalSoftmax<BLOCK_SIZE, float, float><<<grid, BLOCK_SIZE, 0, stream>>>(
            (float *)data, batch_size, seq_len, total_seq_len, stride_b, stride_i);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
// Runs causal softmax in place on `data` using the stream's device.
// Chooses the kernel's compile-time block size from the device's
// maxThreadsPerBlock(); devices reporting anything other than 1024 or
// 512 are rejected. `workspace` is unused (the op needs none).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *data,
    void *stream_) {

    auto stream = (cudaStream_t)stream_;
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();

    if (max_threads == CUDA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            data, _info.dtype, _info.batch_size, _info.seq_len,
            _info.total_seq_len, _info.stride_b, _info.stride_i, stream));
        return INFINI_STATUS_SUCCESS;
    }
    if (max_threads == CUDA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            data, _info.dtype, _info.batch_size, _info.seq_len,
            _info.total_seq_len, _info.stride_b, _info.stride_i, stream));
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
// namespace op::causal_softmax::cuda
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
0 → 100644
View file @
a8955429
#ifndef __CAUSAL_SOFTMAX_CUDA_H__
#define __CAUSAL_SOFTMAX_CUDA_H__

#include "../causal_softmax.h"

// Expands the shared DESCRIPTOR macro (declared in ../causal_softmax.h)
// to declare op::causal_softmax::cuda::Descriptor for the CUDA backend.
DESCRIPTOR(cuda)

#endif
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
0 → 100644
View file @
a8955429
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
#include "../../../devices/cuda/cuda_common.cuh"
// In-place causal softmax over one row per block.
//
// Launch layout (see launchKernel): gridDim.y = batch, gridDim.x = height
// (rows), blockDim.x = BLOCK_SIZE threads striding across the `width`
// columns of the row. Tdata is the storage type, Tcompute the accumulation
// type for the row sum.
//
// Requires width >= 1 (reduce_op::max reads element 0 unconditionally).
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(
    Tdata *data_,
    size_t batch,
    size_t height,
    size_t width,
    ptrdiff_t stride_b,
    ptrdiff_t stride_h) {
    Tdata *data = data_
                + blockIdx.y * stride_b  // gridDim.y for batch_id
                + blockIdx.x * stride_h; // gridDim.x for row_id

    // [Reduce] Find max value in each row and broadcast it through shared
    // memory (only thread 0's reduce result is guaranteed correct).
    __shared__ Tdata max_;
    Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(data, width);
    if (threadIdx.x == 0) {
        max_ = max_0;
    }
    __syncthreads();

    // [Elementwise] Subtract max value from each element and apply causal mask
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        // row_id ↓ |<-     width      ->|
        //     0    | * * * ... *        |
        //     1    | * * * ... * *      |
        //     2    | * * * ... * * *    |
        // height: 3                col_id->
        //
        // Keep (row, col) iff col <= width - height + row, i.e.
        // width + row >= col + height.
        // FIX: compare against `col`, not `threadIdx.x` — they differ on a
        // thread's second and later loop iterations once width > BLOCK_SIZE,
        // which previously left columns unmasked that the causal mask
        // should have zeroed.
        if (width + blockIdx.x >= col + height) {
            data[col] = exp(data[col] - max_);
        } else {
            data[col] = Tdata(0);
        }
    }
    __syncthreads();

    // [Reduce] Sum of the masked exponentials, broadcast via shared memory.
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(data, width);
    if (threadIdx.x == 0) {
        sum_ = sum_0;
    }
    __syncthreads();

    // [Elementwise] Normalize each element by the row sum.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        data[col] /= Tdata(sum_);
    }
}
#endif // __CAUSAL_SOFTMAX_KERNEL_CUH__
src/infiniop/ops/causal_softmax/operator.cc
View file @
a8955429
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
#include "cpu/causal_softmax_cpu.h"
#include "cpu/causal_softmax_cpu.h"
#endif
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/causal_softmax_cuda.cuh"
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -24,11 +27,8 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
...
@@ -24,11 +27,8 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaCreateCausalSoftmaxDescriptor
((
CudaHandle_t
)
handle
,
(
CausalSoftmaxCudaDescriptor_t
*
)
desc_ptr
,
y_desc
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -66,11 +66,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -66,11 +66,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
GET
(
INFINI_DEVICE_CPU
,
cpu
)
GET
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxCudaDescriptor_t
)
desc
,
size
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -114,11 +111,8 @@ __C infiniStatus_t infiniopCausalSoftmax(
...
@@ -114,11 +111,8 @@ __C infiniStatus_t infiniopCausalSoftmax(
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaCausalSoftmax
((
CausalSoftmaxCudaDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -156,11 +150,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
...
@@ -156,11 +150,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
DESTROY
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxCudaDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
...
src/infiniop/reduce/cuda/reduce.cuh
View file @
a8955429
...
@@ -3,8 +3,15 @@
...
@@ -3,8 +3,15 @@
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
/*
* Device functions for reduction operations on CUDA.
*
* Note: Only local result on thread 0 is guranteed to be correct.
* A manual broadcast is needed for other threads.
*/
namespace
op
::
common_cuda
::
reduce_op
{
namespace
op
::
common_cuda
::
reduce_op
{
// Sum(x^2) on contiguous data of length count
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tdata
,
typename
Tcompute
>
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tdata
,
typename
Tcompute
>
__device__
__forceinline__
Tcompute
sumSquared
(
const
Tdata
*
data_ptr
,
size_t
count
)
{
__device__
__forceinline__
Tcompute
sumSquared
(
const
Tdata
*
data_ptr
,
size_t
count
)
{
Tcompute
ss
=
0
;
Tcompute
ss
=
0
;
...
@@ -21,6 +28,36 @@ __device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t cou
...
@@ -21,6 +28,36 @@ __device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t cou
return
BlockReduce
(
temp_storage
).
Sum
(
ss
);
return
BlockReduce
(
temp_storage
).
Sum
(
ss
);
}
}
// Sum(x) on contiguous data of length count.
// Block-cooperative: all BLOCK_SIZE threads must call this together.
// Per the file header note, only thread 0's return value is guaranteed
// correct; broadcast it manually if other threads need it.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute sum(const Tdata *data_ptr, size_t count) {
    // Each thread accumulates a BLOCK_SIZE-strided partial sum,
    // promoted to Tcompute precision.
    Tcompute partial = 0;
    for (size_t idx = threadIdx.x; idx < count; idx += BLOCK_SIZE) {
        partial += Tcompute(data_ptr[idx]);
    }
    // Combine the per-thread partials with a CUB block reduction.
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Sum(partial);
}
// Max(x) on contiguous data of length count.
// Block-cooperative: all BLOCK_SIZE threads must call this together.
// Per the file header note, only thread 0's return value is guaranteed
// correct; broadcast it manually if other threads need it.
template <unsigned int BLOCK_SIZE, typename Tdata>
__device__ __forceinline__ Tdata max(const Tdata *data_ptr, size_t count) {
    // Seed with element 0 so even threads whose stride covers no element
    // contribute a valid value (requires count >= 1).
    Tdata local_max = data_ptr[0];
    for (size_t idx = threadIdx.x; idx < count; idx += BLOCK_SIZE) {
        local_max = cub::Max()(local_max, data_ptr[idx]);
    }
    // Fold the per-thread maxima with a CUB block reduction over all
    // BLOCK_SIZE lanes.
    using BlockReduce = cub::BlockReduce<Tdata, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Reduce(local_max, cub::Max(), BLOCK_SIZE);
}
}
// namespace op::common_cuda::reduce_op
}
// namespace op::common_cuda::reduce_op
#endif
#endif
xmake.lua
View file @
a8955429
...
@@ -4,9 +4,10 @@ local GREEN = '\27[0;32m'
...
@@ -4,9 +4,10 @@ local GREEN = '\27[0;32m'
local
YELLOW
=
'
\27
[1;33m'
local
YELLOW
=
'
\27
[1;33m'
local
NC
=
'
\27
[0m'
-- No Color
local
NC
=
'
\27
[0m'
-- No Color
add_includedirs
(
"include"
)
set_encodings
(
"utf-8"
)
set_encodings
(
"utf-8"
)
add_includedirs
(
"include"
)
if
is_mode
(
"debug"
)
then
if
is_mode
(
"debug"
)
then
add_defines
(
"DEBUG_MODE"
)
add_defines
(
"DEBUG_MODE"
)
end
end
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment