Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f50c2a40
Unverified
Commit
f50c2a40
authored
Aug 27, 2025
by
spike-zhu
Committed by
GitHub
Aug 27, 2025
Browse files
issue/260: 摩尔平台 causal_softmax 算子开发
parent
9c87dbb1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
207 additions
and
1 deletion
+207
-1
src/infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
...infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
+80
-0
src/infiniop/ops/causal_softmax/moore/causal_softmax_moore.h
src/infiniop/ops/causal_softmax/moore/causal_softmax_moore.h
+8
-0
src/infiniop/ops/causal_softmax/moore/causal_softmax_moore.mu
...infiniop/ops/causal_softmax/moore/causal_softmax_moore.mu
+93
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+15
-0
test/infiniop/causal_softmax.py
test/infiniop/causal_softmax.py
+11
-1
No files found.
src/infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
0 → 100644
View file @
f50c2a40
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tdata
,
typename
Tcompute
>
__device__
void
causalSoftmaxKernel
(
Tdata
*
y_
,
const
Tdata
*
x_
,
size_t
batch
,
size_t
height
,
size_t
width
,
ptrdiff_t
y_stride_b
,
ptrdiff_t
y_stride_h
,
ptrdiff_t
x_stride_b
,
ptrdiff_t
x_stride_h
)
{
Tdata
*
y
=
y_
// threadIdx.x for col_id
+
blockIdx
.
y
*
y_stride_b
// gridDim.y for batch_id
+
blockIdx
.
x
*
y_stride_h
;
// gridDim.x for row_id
const
Tdata
*
x
=
x_
+
blockIdx
.
y
*
x_stride_b
+
blockIdx
.
x
*
x_stride_h
;
// [Reduce] Find max value in each row and store in shared memory
__shared__
Tdata
max_
;
Tdata
max_0
=
op
::
common_cuda
::
reduce_op
::
max
<
BLOCK_SIZE
,
Tdata
>
(
x
,
width
-
height
+
1
+
blockIdx
.
x
);
if
(
threadIdx
.
x
==
0
)
{
max_
=
max_0
;
}
__syncthreads
();
// [Elementwise] Subtract max value from each element and apply causal mask
for
(
size_t
col
=
threadIdx
.
x
;
col
<
width
;
col
+=
BLOCK_SIZE
)
{
// row_id ↓ |<- width ->|
// 0 | * * * ... * |
// 1 | * * * ... * * |
// 2 | * * * ... * * * |
// height: 3 col_id->
if
(
width
+
blockIdx
.
x
>=
threadIdx
.
x
+
height
)
{
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
||
std
::
is_same_v
<
Tdata
,
cuda_bfloat16
>
)
{
/*
* MUSA does not support CUDA's native `hexp` function.
* This code performs an explicit conversion:
* it casts the input to `float`, computes the exponential, and casts the result back.
* This ensures compatibility and correct behavior on the MUSA platform.
*/
float
val
=
static_cast
<
float
>
(
x
[
col
])
-
static_cast
<
float
>
(
max_
);
y
[
col
]
=
static_cast
<
Tdata
>
(
expf
(
val
));
}
else
{
y
[
col
]
=
exp
(
x
[
col
]
-
max_
);
}
}
else
{
/*
* In MUSA, the `__mt_bfloat16` type has ambiguous constructors for integer literals (e.g., `0`),
* as it could be implicitly converted from either `float` or `double`.
*
* This differs from CUDA's `half` type, which can typically be initialized
* from integer literals without ambiguity.
*
* To resolve this, we use the float literal `0.0f` to explicitly
* specify the conversion path, ensuring platform compatibility.
*/
y
[
col
]
=
Tdata
(
0.0
f
);
}
}
__syncthreads
();
// [Reduce] Find the sum of each updated row and store in shared memory
__shared__
Tcompute
sum_
;
Tcompute
sum_0
=
op
::
common_cuda
::
reduce_op
::
sum
<
BLOCK_SIZE
,
Tdata
,
Tcompute
>
(
y
,
width
);
if
(
threadIdx
.
x
==
0
)
{
sum_
=
sum_0
;
}
__syncthreads
();
// [Elementwise] Divide each element by the sum and store in shared memory
for
(
size_t
col
=
threadIdx
.
x
;
col
<
width
;
col
+=
BLOCK_SIZE
)
{
/*
* MUSA's bfloat16 type does not have a viable overloaded `/=` operator for float division.
* This change explicitly casts both operands to `float` before division,
* and then casts the result back to `Tdata`.
* This ensures the operation is performed correctly and avoids compilation errors.
*/
y
[
col
]
=
static_cast
<
Tdata
>
(
static_cast
<
float
>
(
y
[
col
])
/
static_cast
<
float
>
(
sum_
));
}
}
#endif // __CAUSAL_SOFTMAX_KERNEL_CUH__
src/infiniop/ops/causal_softmax/moore/causal_softmax_moore.h
0 → 100644
View file @
f50c2a40
#ifndef __CAUSAL_SOFTMAX_MOORE_H__
#define __CAUSAL_SOFTMAX_MOORE_H__
#include "../causal_softmax.h"
DESCRIPTOR
(
moore
)
#endif
src/infiniop/ops/causal_softmax/moore/causal_softmax_moore.mu
0 → 100644
View file @
f50c2a40
#include "../../../devices/moore/moore_common.h"
#include "causal_softmax_moore.h"
#include <cub/block/block_reduce.cuh>
#include "../../../devices/moore/moore_kernel_common.h"
#include "../../../reduce/cuda/reduce.cuh"
#include "causal_softmax_kernel.h"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MOORE_KERNEL causalSoftmax(
Tdata *y, const Tdata *x,
size_t batch, size_t height, size_t width,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_h) {
causalSoftmaxKernel<BLOCK_SIZE, Tdata, Tcompute>(y, x, batch, height, width, y_stride_b, y_stride_h, x_stride_b, x_stride_h);
}
namespace op::causal_softmax::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
size_t batch_size, size_t seq_len, size_t total_seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
musaStream_t stream) {
dim3 grid(uint32_t(seq_len), uint32_t(batch_size), 1);
if (dtype == INFINI_DTYPE_F16) {
causalSoftmax<BLOCK_SIZE, half, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((half *)y, (const half *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_BF16) {
causalSoftmax<BLOCK_SIZE, __mt_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__mt_bfloat16 *)y, (const __mt_bfloat16 *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::causal_softmax::moore
src/infiniop/ops/causal_softmax/operator.cc
View file @
f50c2a40
...
@@ -20,6 +20,9 @@
...
@@ -20,6 +20,9 @@
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/causal_softmax_kunlun.h"
#include "kunlun/causal_softmax_kunlun.h"
#endif
#endif
#ifdef ENABLE_MOORE_API
#include "moore/causal_softmax_moore.h"
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -56,6 +59,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
...
@@ -56,6 +59,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -89,6 +95,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -89,6 +95,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
...
@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -160,6 +172,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
...
@@ -160,6 +172,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
DESTROY
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
DESTROY
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
test/infiniop/causal_softmax.py
View file @
f50c2a40
...
@@ -69,7 +69,17 @@ NUM_ITERATIONS = 1000
...
@@ -69,7 +69,17 @@ NUM_ITERATIONS = 1000
def
causal_softmax
(
x
):
def
causal_softmax
(
x
):
type
=
x
.
dtype
type
=
x
.
dtype
mask
=
torch
.
tril
(
torch
.
ones_like
(
x
),
diagonal
=-
1
).
flip
(
dims
=
[
-
2
,
-
1
])
# Issue: torch_musa's implementation of `torch.tril` has a known bug for certain shapes (e.g., (32, 5, 5)).
# Workaround: Generate the lower triangular mask on the CPU and then transfer it to the MUSA device.
if
x
.
device
.
type
==
"musa"
:
mask
=
(
torch
.
tril
(
torch
.
ones_like
(
x
).
to
(
"cpu"
),
diagonal
=-
1
)
.
flip
(
dims
=
[
-
2
,
-
1
])
.
to
(
"musa"
)
)
else
:
mask
=
torch
.
tril
(
torch
.
ones_like
(
x
),
diagonal
=-
1
).
flip
(
dims
=
[
-
2
,
-
1
])
masked
=
torch
.
where
(
mask
==
1
,
-
torch
.
inf
,
x
.
to
(
torch
.
float32
))
masked
=
torch
.
where
(
mask
==
1
,
-
torch
.
inf
,
x
.
to
(
torch
.
float32
))
return
torch
.
nn
.
functional
.
softmax
(
masked
,
dim
=-
1
,
dtype
=
type
)
return
torch
.
nn
.
functional
.
softmax
(
masked
,
dim
=-
1
,
dtype
=
type
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment