Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8ccd42bf
Unverified
Commit
8ccd42bf
authored
Jun 11, 2025
by
PanZezhong1725
Committed by
GitHub
Jun 11, 2025
Browse files
Merge pull request #244 from InfiniTensor/issue/39
issue/39 Migrate cuda causal softmax to metax
parents
0f132536
04e2294d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
218 additions
and
0 deletions
+218
-0
src/infiniop/ops/causal_softmax/maca/causal_softmax_kernel.h
src/infiniop/ops/causal_softmax/maca/causal_softmax_kernel.h
+60
-0
src/infiniop/ops/causal_softmax/maca/causal_softmax_maca.h
src/infiniop/ops/causal_softmax/maca/causal_softmax_maca.h
+8
-0
src/infiniop/ops/causal_softmax/maca/causal_softmax_maca.maca
...infiniop/ops/causal_softmax/maca/causal_softmax_maca.maca
+72
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+15
-0
src/infiniop/reduce/maca/reduce.h
src/infiniop/reduce/maca/reduce.h
+63
-0
No files found.
src/infiniop/ops/causal_softmax/maca/causal_softmax_kernel.h
0 → 100644
View file @
8ccd42bf
#ifndef __CAUSAL_SOFTMAX_KERNEL_H__
#define __CAUSAL_SOFTMAX_KERNEL_H__
#include "../../../devices/maca/maca_kernel_common.h"
#include "../../../reduce/maca/reduce.h"
/// Row-wise causal softmax over a (batch, height, width) tensor.
/// Launch shape: grid = (height, batch), block = BLOCK_SIZE threads.
///   gridDim.y -> batch index, gridDim.x -> row index,
///   threadIdx.x strides across the `width` columns of one row.
/// Only the bottom-right `height` rows of a width-wide causal attention
/// matrix are materialized: row r keeps columns col <= width - height + r.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MACA_KERNEL causalSoftmax(
    Tdata *y_,
    const Tdata *x_,
    size_t batch, // unused in the body: the batch index is carried by gridDim.y
    size_t height,
    size_t width,
    ptrdiff_t y_stride_b,
    ptrdiff_t y_stride_h,
    ptrdiff_t x_stride_b,
    ptrdiff_t x_stride_h) {
    // Offset to this block's (batch, row) slice; columns are contiguous.
    Tdata *y = y_
             + blockIdx.y * y_stride_b  // gridDim.y for batch_id
             + blockIdx.x * y_stride_h; // gridDim.x for row_id
    const Tdata *x = x_
                   + blockIdx.y * x_stride_b
                   + blockIdx.x * x_stride_h;

    // [Reduce] Max of the whole row (including the masked tail). Including
    // masked elements only shifts the exponent uniformly, so the normalized
    // softmax is unchanged; it can only cost a little dynamic range.
    __shared__ Tdata max_;
    Tdata max_0 = op::common_maca::reduce_op::max<BLOCK_SIZE, Tdata>(x, width);
    if (threadIdx.x == 0) {
        max_ = max_0; // broadcast thread 0's reduction result via shared memory
    }
    __syncthreads();

    // [Elementwise] Subtract max value from each element and apply causal mask.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        //   row_id ↓ |<-     width     ->|
        //   0        | *  *  * ... *     |
        //   1        | *  *  * ... *  *  |
        //   2        | *  *  * ... *  * *|
        //   height: 3            col_id->
        // Keep when col <= width - height + row, i.e. width + row >= col + height.
        // FIX: the mask must test `col`, not `threadIdx.x` — on the second and
        // later strides of this loop (width > BLOCK_SIZE) `threadIdx.x` is the
        // first-pass column, which let values leak past the causal boundary.
        if (width + blockIdx.x >= col + height) {
#ifdef ENABLE_MACA_API
            y[col] = exp_(x[col] - max_);
#else
            y[col] = exp(x[col] - max_);
#endif
        } else {
            y[col] = Tdata(0); // masked-out position contributes nothing to the sum
        }
    }
    __syncthreads();

    // [Reduce] Sum of the exponentiated row (masked entries are exact zeros).
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_maca::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
    if (threadIdx.x == 0) {
        sum_ = sum_0;
    }
    __syncthreads();

    // [Elementwise] Normalize each element by the row sum.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        y[col] /= Tdata(sum_);
    }
}
#endif // __CAUSAL_SOFTMAX_KERNEL_H__
src/infiniop/ops/causal_softmax/maca/causal_softmax_maca.h
0 → 100644
View file @
8ccd42bf
#ifndef __CAUSAL_SOFTMAX_MACA_H__
#define __CAUSAL_SOFTMAX_MACA_H__
#include "../causal_softmax.h"
// Instantiate the causal_softmax Descriptor class declaration for the MACA
// backend (the DESCRIPTOR macro is defined in ../causal_softmax.h; the
// member definitions live in causal_softmax_maca.maca).
DESCRIPTOR(maca)
#endif
src/infiniop/ops/causal_softmax/maca/causal_softmax_maca.maca
0 → 100644
View file @
8ccd42bf
#include "../../../devices/maca/common_maca.h"
#include "causal_softmax_kernel.h"
#include "causal_softmax_maca.h"
namespace op::causal_softmax::maca {

// Backend-private state: a handle to the shared MACA device internals,
// used at calculate() time to query device limits.
struct Descriptor::Opaque {
    std::shared_ptr<device::maca::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

// Validate the tensor pair, derive the softmax shape/stride info, and build
// the backend descriptor. Workspace size is always 0 for this operator.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto info_result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(info_result);
    auto *maca_handle = reinterpret_cast<device::maca::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{maca_handle->internal()},
        info_result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Dispatch the kernel for the runtime dtype. Grid: one block per
// (row, batch) pair; BLOCK_SIZE threads stride the columns.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                            size_t batch_size, size_t seq_len, size_t total_seq_len,
                            ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                            ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
                            hcStream_t stream) {
    dim3 grid_dim(uint32_t(seq_len), uint32_t(batch_size), 1);
    switch (dtype) {
    case INFINI_DTYPE_F16:
        causalSoftmax<BLOCK_SIZE, half, float>
            <<<grid_dim, BLOCK_SIZE, 0, stream>>>(
                reinterpret_cast<half *>(y), reinterpret_cast<const half *>(x),
                batch_size, seq_len, total_seq_len,
                y_stride_b, y_stride_i,
                x_stride_b, x_stride_i);
        break;
    case INFINI_DTYPE_F32:
        causalSoftmax<BLOCK_SIZE, float, float>
            <<<grid_dim, BLOCK_SIZE, 0, stream>>>(
                reinterpret_cast<float *>(y), reinterpret_cast<const float *>(x),
                batch_size, seq_len, total_seq_len,
                y_stride_b, y_stride_i,
                x_stride_b, x_stride_i);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

// Pick the block size matching the device's thread-per-block limit and run.
// No workspace is needed, so `workspace`/`workspace_size` are ignored.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *y,
                                     const void *x,
                                     void *stream_) const {
    auto stream = static_cast<hcStream_t>(stream_);
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads == MACA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>(
            y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
    } else if (max_threads == MACA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_512>(
            y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::causal_softmax::maca
src/infiniop/ops/causal_softmax/operator.cc
View file @
8ccd42bf
...
@@ -8,6 +8,9 @@
...
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
#include "cuda/causal_softmax_cuda.cuh"
#include "cuda/causal_softmax_cuda.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
#include "maca/causal_softmax_maca.h"
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
#include "ascend/causal_softmax_ascend.h"
#include "ascend/causal_softmax_ascend.h"
#endif
#endif
...
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
...
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
#endif
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
maca
)
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
return
bangCreateCausalSoftmaxDescriptor
((
BangHandle_t
)
handle
,
(
CausalSoftmaxBangDescriptor_t
*
)
desc_ptr
,
y_desc
);
return
bangCreateCausalSoftmaxDescriptor
((
BangHandle_t
)
handle
,
(
CausalSoftmaxBangDescriptor_t
*
)
desc_ptr
,
y_desc
);
...
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
maca
)
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
return
macaGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxMacaDescriptor_t
)
desc
,
size
);
return
macaGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxMacaDescriptor_t
)
desc
,
size
);
...
@@ -113,6 +122,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
...
@@ -113,6 +122,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
#endif
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
maca
)
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
return
bangCausalSoftmax
((
CausalSoftmaxBangDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
return
bangCausalSoftmax
((
CausalSoftmaxBangDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
...
@@ -150,6 +162,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
...
@@ -150,6 +162,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
cuda
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
cuda
)
#endif
#endif
#ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
maca
)
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
return
bangDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxBangDescriptor_t
)
desc
);
return
bangDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxBangDescriptor_t
)
desc
);
...
...
src/infiniop/reduce/maca/reduce.h
0 → 100644
View file @
8ccd42bf
#ifndef __INFINIOP_REDUCE_MACA_H__
#define __INFINIOP_REDUCE_MACA_H__
#include <hccub/block/block_reduce.cuh>
/*
* Device functions for reduction operations on MACA.
*
* Note: Only the local result on thread 0 is guaranteed to be correct.
* A manual broadcast is needed for other threads.
*/
namespace op::common_maca::reduce_op {

// Block-wide Sum(x^2) over `count` contiguous elements.
// Accumulates in Tcompute. Only thread 0 receives the reduced value;
// broadcast through shared memory if other threads need it.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t count) {
    Tcompute partial = 0;
    // Block-stride loop: thread t handles indices t, t+BLOCK_SIZE, ...
    for (size_t idx = threadIdx.x; idx < count; idx += BLOCK_SIZE) {
        const Tcompute v = Tcompute(data_ptr[idx]);
        partial += v * v;
    }
    // Combine the per-thread partial sums with a CUB block-level reduction.
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Sum(partial);
}

// Block-wide Sum(x) over `count` contiguous elements, accumulated in Tcompute.
// Only thread 0's return value is guaranteed correct.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute sum(const Tdata *data_ptr, size_t count) {
    Tcompute partial = 0;
    for (size_t idx = threadIdx.x; idx < count; idx += BLOCK_SIZE) {
        partial += Tcompute(data_ptr[idx]);
    }
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Sum(partial);
}

// Block-wide Max(x) over `count` contiguous elements.
// NOTE(review): every thread seeds from data_ptr[0], so `count` must be >= 1.
// Only thread 0's return value is guaranteed correct.
template <unsigned int BLOCK_SIZE, typename Tdata>
__device__ __forceinline__ Tdata max(const Tdata *data_ptr, size_t count) {
    // Seeding with element 0 keeps threads whose slice is empty neutral.
    Tdata running_max = data_ptr[0];
    for (size_t idx = threadIdx.x; idx < count; idx += BLOCK_SIZE) {
        running_max = cub::Max()(running_max, data_ptr[idx]);
    }
    using BlockReduce = cub::BlockReduce<Tdata, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    return BlockReduce(temp_storage).Reduce(running_max, cub::Max(), BLOCK_SIZE);
}

} // namespace op::common_maca::reduce_op
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment