Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c5167eb7
Commit
c5167eb7
authored
Aug 22, 2025
by
zhangyue
Browse files
issue/390: kunlun p800 causal softmax
parent
c35920e2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
315 additions
and
12 deletions
+315
-12
src/infiniop/devices/kunlun/kunlun_kernel_common.h
src/infiniop/devices/kunlun/kunlun_kernel_common.h
+92
-0
src/infiniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.h
...nfiniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.h
+8
-0
src/infiniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.xpu
...iniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.xpu
+110
-0
src/infiniop/ops/causal_softmax/kunlun/kernel.h
src/infiniop/ops/causal_softmax/kunlun/kernel.h
+65
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+15
-0
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h
+0
-10
src/infiniop/reduce/kunlun/reduce_kunlun.h
src/infiniop/reduce/kunlun/reduce_kunlun.h
+25
-2
No files found.
src/infiniop/devices/kunlun/kunlun_kernel_common.h
View file @
c5167eb7
...
...
@@ -114,6 +114,29 @@ inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *pt
return
old
;
}
/**
 * @brief atomicMax for kunlun xpu shared memory.
 *
 * Implemented as a critical section: a cluster-wide ticket lock guards a
 * read-modify-write of *ptr, followed by a shared-memory fence before the
 * lock is released. bfloat16_t has no native fmax, so that case is widened
 * to float, compared, and narrowed back with round-to-nearest.
 *
 * @param ptr: pointer to shared memory
 * @param value: value to compare
 * @return the value held at *ptr before this update (classic atomic-op contract)
 */
template <typename T>
inline __device__ T atomicMax(__shared_ptr__ T *ptr, T value) {
    ticket_lock_mix();              // serialize all cores in the cluster
    T old = loadsm(ptr);            // snapshot current value under the lock
    if constexpr (std::is_same<T, bfloat16_t>::value) {
        // bfloat16 path: compare in float precision, then narrow back
        float of = __bfloat162float(old);
        float vf = __bfloat162float(value);
        float maxf = fmax(of, vf);
        bfloat16_t max = __float2bfloat16_rn(maxf);
        *ptr = max;
    } else {
        // NOTE(review): fmax is the float overload here — presumably T is a
        // floating type (half/float); confirm no integer instantiations exist.
        *ptr = fmax(old, value);
    }
    mfence_sm();                    // make the store visible before unlocking
    ticket_unlock_mix();
    return old;
}
/**
* @brief Get index of broadcasted input
* flat_index: flatten index of output tensor
...
...
@@ -156,6 +179,75 @@ inline __device__ int indexToOffset(
return
res
;
}
/**
 * @brief Get max of an array in local memory (scalar fallback).
 * @param data_ptr: pointer to local memory
 * @param len: length of array (must be >= 1)
 * @return max value
 */
template <typename T>
__inline__ __device__ T max(const T *data_ptr, size_t len) {
    // Seed with the first element, then fold the rest in with fmax.
    T best = data_ptr[0];
    for (size_t idx = 1; idx < len; ++idx) {
        best = fmax(best, data_ptr[idx]);
    }
    return best;
}
// Use simd vector instruction to calculate max of a half array.
// Splits len into a 32-aligned prefix handled with float16x32 vector ops
// and a scalar tail of len % 32 elements.
template <>
__inline__ __device__ half max(const half *data_ptr, size_t len) {
    int remain = len % 32;            // scalar tail length
    int offset_last = len - remain;   // first index of the tail (32-aligned)
    half res = data_ptr[0];
    // Scalar pass over the tail [offset_last, len)
    for (int i = offset_last; i < len; i++) {
        res = fmax(res, *(data_ptr + i));
    }
    mfence();
    if (offset_last != 0) {
        __local__ half acc_buf[32];
        // Running vector maximum, seeded with the first 32 halves
        float16x32_t v_mv = vload_lm_float16x32_mz(data_ptr);
        // for every 32 half values in the aligned prefix
        for (int i = 32; i < offset_last; i += 32) {
            float16x32_t v_0 = vload_lm_float16x32_mz(data_ptr + i);
            v_mv = vvmax_float16x32_mz(v_mv, v_0);
        }
        // Spill the 32 lane-wise maxima and reduce them scalar-wise
        vstore_lm_float16x32_mz(acc_buf, v_mv);
        mfence();
        for (int i = 0; i < 32; i++) {
            res = fmax(res, acc_buf[i]);
        }
    }
    return res;
}
// Use simd vector instruction to calculate max of a float array.
// Splits len into a 16-aligned prefix handled with float32x16 vector ops
// and a scalar tail of len % 16 elements.
template <>
__inline__ __device__ float max(const float *data_ptr, size_t len) {
    int remain = len % 16;            // scalar tail length
    int offset_last = len - remain;   // first index of the tail (16-aligned)
    float res = data_ptr[0];
    // Scalar pass over the tail [offset_last, len)
    for (int i = offset_last; i < len; i++) {
        res = fmax(res, *(data_ptr + i));
    }
    mfence();
    if (offset_last != 0) {
        __local__ float acc_buf[16];
        // Running vector maximum, seeded with the first 16 floats
        float32x16_t v_mv = vload_lm_float32x16_mz(data_ptr);
        // for every 16 float data in the aligned prefix
        for (int i = 16; i < offset_last; i += 16) {
            float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
            v_mv = vvmax_float32x16_mz(v_mv, v_0);
        }
        // Spill the 16 lane-wise maxima and reduce them scalar-wise
        vstore_lm_float32x16_mz(acc_buf, v_mv);
        mfence();
        for (int i = 0; i < 16; i++) {
            res = fmax(res, acc_buf[i]);
        }
    }
    return res;
}
}
// namespace device::kunlun::kernel
#endif // __INFINIOP_KUNLUN_KERNEL_COMMON_H__
src/infiniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.h
0 → 100644
View file @
c5167eb7
#ifndef __CAUSAL_SOFTMAX_KUNLUN_H__
#define __CAUSAL_SOFTMAX_KUNLUN_H__

#include "../causal_softmax.h"

// Instantiate the op::causal_softmax::kunlun::Descriptor declaration via the
// shared DESCRIPTOR macro (defined in ../causal_softmax.h).
DESCRIPTOR(kunlun)

#endif // __CAUSAL_SOFTMAX_KUNLUN_H__
src/infiniop/ops/causal_softmax/kunlun/causal_softmax_kunlun.xpu
0 → 100644
View file @
c5167eb7
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "causal_softmax_kunlun.h"
#include "kernel.h"
// Causal-softmax kernel for one row of one batch.
// Launched with one cluster per row (cluster_id() == row index) and
// BLOCK_SIZE cores per cluster. Stages the row through shared memory:
// GM -> x_sm, compute in causalSoftmaxBlock, y_sm -> GM.
// NOTE(review): assumes width * sizeof(Tdata) <= SM_SIZE — confirm callers
// guarantee this, there is no runtime check here.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__global__ void causalSoftmaxKernel(
    Tdata *y,
    const Tdata *x,
    uint32_t batch,
    uint32_t height,
    uint32_t width,
    int32_t y_stride_h,
    int32_t x_stride_h) {
    // Shared-memory staging buffers for the input and output row.
    __shared__ Tdata x_sm[SM_SIZE / sizeof(Tdata)];
    __shared__ Tdata y_sm[SM_SIZE / sizeof(Tdata)];
    int row_id = cluster_id();
    __global_ptr__ Tdata *y_ = y + row_id * y_stride_h;
    __global_ptr__ const Tdata *x_ = x + row_id * x_stride_h;
    // Core 0 issues the async copy; the cluster barrier completes it.
    if (core_id() == 0) {
        GM2SM_ASYNC(x_, x_sm, width * sizeof(Tdata));
    }
    sync_cluster();
    causalSoftmaxBlock<BLOCK_SIZE, Tdata, Tcompute>(y_sm, x_sm, height, width, row_id);
    // Core 0 writes the finished row back to global memory.
    if (core_id() == 0) {
        SM2GM_ASYNC(y_sm, y_, width * sizeof(Tdata));
    }
    sync_cluster();
}
namespace op::causal_softmax::kunlun {

// Device-specific state carried by the public Descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

// Validate tensor descriptors and build a Kunlun causal-softmax descriptor.
// Returns INFINI_STATUS_SUCCESS, or the error carried by CausalSoftmaxInfo::create.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(info);
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::kunlun::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Dispatch the per-batch kernel launches for the given dtype.
// One launch per batch; each launch runs seq_len clusters of BLOCK_SIZE cores.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                            size_t batch_size, size_t seq_len, size_t total_seq_len,
                            ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
                            ptrdiff_t x_stride_b, ptrdiff_t x_stride_h,
                            kunlunStream_t stream) {
    // Kunlunxin kernels don't support ptrdiff_t and size_t as parameters,
    // so narrow everything to 32-bit before passing it to the kernel.
    uint32_t batch_size_ = static_cast<uint32_t>(batch_size);
    uint32_t seq_len_ = static_cast<uint32_t>(seq_len);
    uint32_t total_seq_len_ = static_cast<uint32_t>(total_seq_len);
    int32_t y_stride_b_ = static_cast<int32_t>(y_stride_b);
    int32_t y_stride_h_ = static_cast<int32_t>(y_stride_h);
    int32_t x_stride_b_ = static_cast<int32_t>(x_stride_b);
    int32_t x_stride_h_ = static_cast<int32_t>(x_stride_h);
// BUGFIX: previously passed the original size_t/ptrdiff_t variables to the
// kernel despite the comment above; now the narrowed 32-bit copies are used.
#define LAUNCH_KERNEL(Tdata, Tcompute)                                                                      \
    for (uint32_t i = 0; i < batch_size_; ++i) {                                                            \
        causalSoftmaxKernel<BLOCK_SIZE, Tdata, Tcompute>                                                    \
            <<<seq_len_, BLOCK_SIZE, stream>>>((Tdata *)y + i * y_stride_b_,                                \
                                               (const Tdata *)x + i * x_stride_b_,                          \
                                               batch_size_, seq_len_, total_seq_len_,                       \
                                               y_stride_h_, x_stride_h_);                                   \
    }
    if (dtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, float);
    } else if (dtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(bfloat16_t, float);
    } else if (dtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
#undef LAUNCH_KERNEL
}

// Public entry point: run causal softmax with the shapes/strides captured at
// create() time. workspace is unused (workspace_size requirement is 0).
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *y,
                                     const void *x,
                                     void *stream_) const {
    kunlunStream_t stream = (kunlunStream_t)stream_;
    CHECK_STATUS(launchKernel<64>(
        y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
        _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::causal_softmax::kunlun
src/infiniop/ops/causal_softmax/kunlun/kernel.h
0 → 100644
View file @
c5167eb7
#ifndef __CAUSAL_SOFTMAX_KUNLUN_KERNEL_H__
#define __CAUSAL_SOFTMAX_KUNLUN_KERNEL_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;

// Causal softmax over one row held in shared memory.
// y/x point at shared-memory rows of length `width`; `height` is the number
// of rows in the (seq_len x total_seq_len) attention block and `row_id` is
// this cluster's row. Columns beyond the causal boundary are zeroed.
// All BLOCK_SIZE cores of the cluster must call this together: it contains
// cluster-wide barriers (sync_cluster) that every core must reach.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ void causalSoftmaxBlock(
    __shared_ptr__ Tdata *y,
    __shared_ptr__ const Tdata *x,
    size_t height,
    size_t width,
    int row_id) {
    // Reduce max for each row and store in shared memory.
    // Only the first (width - height + 1 + row_id) columns are unmasked.
    __shared__ Tdata max_;
    Tdata max_0 = op::common_kunlun::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + size_t(row_id));
    if (core_id() == 0) {
        max_ = max_0;
    }
    sync_cluster();
    // Elementwise sub max for each element and apply causal mask:
    // exp(x - max) inside the causal region, 0 outside.
    for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
        // row_id ↓ |<-     width     ->|
        //  0       | *  *  * ... *     |
        //  1       | *  *  * ... *  *  |
        //  2       | *  *  * ... *  *  *  |
        //  height: 3   col_id->
        if (width + size_t(row_id) >= col + height) {
            if constexpr (std::is_same_v<Tdata, half>) {
                y[col] = hexp(loadsm(x + col) - loadsm(&max_));
            } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
                // bfloat16 has no native exp: widen to float and narrow back
                y[col] = __float2bfloat16(exp(__bfloat162float(x[col]) - __bfloat162float(max_)));
            } else {
                y[col] = exp(x[col] - max_);
            }
        } else {
            y[col] = Tdata(0);
        }
    }
    sync_cluster();
    // Reduce sum for each row (accumulated in Tcompute precision)
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_kunlun::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
    if (core_id() == 0) {
        sum_ = sum_0;
    }
    sync_cluster();
    // Apply softmax normalization: y /= sum, guarding against a zero sum
    for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
        if (sum_ != 0) {
            y[col] = to<Tdata>(to<Tcompute>(loadsm(y + col)) / sum_);
        } else {
            y[col] = Tdata(0);
        }
    }
    sync_cluster();
}
#endif
src/infiniop/ops/causal_softmax/operator.cc
View file @
c5167eb7
...
...
@@ -17,6 +17,9 @@
#ifdef ENABLE_CAMBRICON_API
#include "bang/causal_softmax_bang.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/causal_softmax_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -50,6 +53,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#endif
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#endif
#ifdef ENABLE_CAMBRICON_API
GET
(
INFINI_DEVICE_CAMBRICON
,
bang
)
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -115,6 +124,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -145,6 +157,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#endif
#ifdef ENABLE_ASCEND_API
DESTROY
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h
View file @
c5167eb7
...
...
@@ -5,14 +5,4 @@
DESCRIPTOR
(
kunlun
)
#define INSTANTIATE_RMSNORM_KERNEL(BLOCK_SIZE, Tcompute, Tdata, Tweight) \
template __global__ void rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight>( \
Tdata * y, \
int32_t stride_y, \
const Tdata *x, \
int32_t stride_x, \
const Tweight *w, \
uint32_t dim, \
float epsilon);
#endif
src/infiniop/reduce/kunlun/reduce_kunlun.h
View file @
c5167eb7
...
...
@@ -26,7 +26,7 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
atomicAdd
(
&
temp_storage
,
ss
);
sync_cluster
();
return
temp_storage
;
return
loadsm
(
&
temp_storage
)
;
}
// Sum(x) on contiguous data of length count
...
...
@@ -48,7 +48,30 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
atomicAdd
(
&
temp_storage
,
ss
);
sync_cluster
();
return
temp_storage
;
return
loadsm
(
&
temp_storage
);
}
// Max(x) on contiguous data of length count.
// Cluster-cooperative: every core scans a BLOCK_SIZE-strided slice of the
// data, then the per-core partial maxima are merged through a shared-memory
// cell with atomicMax. All cores must call this together (it contains
// cluster barriers); every core returns the same final maximum.
// Precondition: count >= 1 (element 0 is used as the seed).
template <unsigned int BLOCK_SIZE, typename Tdata>
__device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) {
    // Per-core partial max over the strided slice [core_id, count) step BLOCK_SIZE.
    // (Cleanup: removed a dead commented-out duplicate load and a redundant
    // to<Tdata> cast — loadsm already yields Tdata here.)
    Tdata max_val = loadsm(data_ptr);
    for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
        Tdata xi = loadsm(data_ptr + i);
        max_val = fmax(max_val, xi);
    }
    // Merge partials: seed the shared cell with element 0 (a valid identity,
    // since it is part of the data), then fold every core's partial in.
    __shared__ Tdata temp_storage;
    if (core_id() == 0) {
        temp_storage = loadsm(data_ptr);
    }
    sync_cluster();
    atomicMax(&temp_storage, max_val);
    sync_cluster();
    return loadsm(&temp_storage);
}
}
// namespace op::common_kunlun::reduce_op
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment