Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8c798185
Commit
8c798185
authored
Sep 02, 2025
by
zhangyue
Browse files
issue/416: p800 rearrange kernel
parent
0e1c5585
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
276 additions
and
0 deletions
+276
-0
src/infiniop/ops/rearrange/kunlun/kernel.h
src/infiniop/ops/rearrange/kunlun/kernel.h
+64
-0
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.h
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.h
+83
-0
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.xpu
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.xpu
+117
-0
src/infiniop/ops/rearrange/operator.cc
src/infiniop/ops/rearrange/operator.cc
+12
-0
No files found.
src/infiniop/ops/rearrange/kunlun/kernel.h
0 → 100644
View file @
8c798185
#ifndef __REARRANGE_KUNLUN_KERNEL_H__
#define __REARRANGE_KUNLUN_KERNEL_H__

#include "../../../devices/kunlun/kunlun_kernel_common.h"

using namespace device::kunlun::kernel;

/**
 * @brief rearrange kernel function
 *
 * Copies a strided source tensor into a (differently) strided destination
 * tensor element by element. Each logical index in [0, total_size) is
 * converted to a source offset and a destination offset via indexToOffset,
 * staged through a per-core local buffer of BUFF_SIZE elements.
 *
 * @tparam BUFF_SIZE the block size of the kernel (local staging buffer, in elements)
 * @tparam T the data type of the input and output tensor
 * @param x the input tensor
 * @param y the output tensor
 * @param shape the shape of the input tensor
 * @param x_stride the stride of the input tensor
 * @param y_stride the stride of the output tensor
 * @param total_size the total size of the input tensor
 */
template <unsigned int BUFF_SIZE, typename Tdata>
__global__ void rearrangeKernel(
    Tdata *y,
    const Tdata *x,
    const void *shape,
    const void *x_stride,
    const void *y_stride,
    uint32_t ndim,
    uint32_t total_size) {
    int cid = core_id();
    int ncores = core_num();
    if (cid >= ncores) {
        return;
    }
    // Flatten (cluster, core) into one linear thread id over all clusters.
    int thread_id = ncores * cluster_id() + cid;
    int nthreads = ncores * cluster_num();

    // Local staging buffer plus local copies of shape/stride metadata.
    __local__ Tdata x_local[BUFF_SIZE];
    __local__ _size_t shape_lm[ndim];
    __local__ _ptrdiff_t x_stride_lm[ndim];
    __local__ _ptrdiff_t y_stride_lm[ndim];

    // NOTE(review): the host side uploads arrays of host `size_t`/`ptrdiff_t`
    // while these copies read `_size_t`/`_ptrdiff_t` — assumes the host and
    // device-kernel types have identical sizes; confirm in kunlun_kernel_common.h.
    GM2LM_ASYNC(shape, shape_lm, ndim * sizeof(_size_t));
    GM2LM_ASYNC(x_stride, x_stride_lm, ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(y_stride, y_stride_lm, ndim * sizeof(_ptrdiff_t));
    mfence(); // make the async metadata loads visible before use

    // Each thread handles chunks of at most BUFF_SIZE (capped by an even
    // split of total_size across threads), striding by nthreads chunks.
    int len_per_loop = min(BUFF_SIZE, roundup_div(total_size, nthreads));
    for (int start = thread_id * len_per_loop; start < total_size; start += nthreads * len_per_loop) {
        int len = min(len_per_loop, total_size - start); // tail chunk may be short
        // Gather: one element-sized async load per logical index.
        for (int idx = start; idx < start + len; ++idx) {
            int in_idx = indexToOffset(idx, ndim, shape_lm, x_stride_lm);
            GM2LM_ASYNC(x + in_idx, x_local + idx - start, sizeof(Tdata));
        }
        mfence(); // loads must complete before scattering back out
        // Scatter: write each staged element to its destination offset.
        for (int idx = start; idx < start + len; ++idx) {
            int out_idx = indexToOffset(idx, ndim, shape_lm, y_stride_lm);
            LM2GM_ASYNC(x_local + idx - start, y + out_idx, sizeof(Tdata));
        }
        // NOTE(review): sync_cluster() sits inside the loop; if threads in a
        // cluster execute different trip counts (total_size not a multiple of
        // nthreads * len_per_loop) this looks like it could mismatch — verify
        // the barrier semantics on divergent loop exits.
        sync_cluster();
    }
}
#endif
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.h
0 → 100644
View file @
8c798185
#ifndef __REARRANGE_KUNLUN_H__
#define __REARRANGE_KUNLUN_H__
#include "../../../tensor.h"
#include "../rearrange.h"
#include <numeric>
namespace
op
::
rearrange
::
kunlun
{
/// Validated metadata for a rearrange (strided copy) between two tensors:
/// the common shape, the source/destination strides, element dtype, and the
/// number of device bytes needed to stage shape + both stride arrays.
struct RearrangeInfo {
    std::vector<size_t> shape;
    std::vector<ptrdiff_t> src_strides;
    std::vector<ptrdiff_t> dst_strides;
    infiniDtype_t dtype;
    // Device space Size for shape, src_strides, dst_strides
    size_t workspace_size;

    /// Total number of elements (product of all dims; 1 for a 0-d tensor).
    /// The init value is size_t{1} so std::accumulate's accumulator is
    /// size_t — a plain `1` would deduce int and overflow on huge tensors.
    size_t nelements() const {
        return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    }

    /// Tensor rank.
    size_t ndim() const {
        return shape.size();
    }

    /// Bytes of device workspace required for the metadata arrays.
    size_t workspaceSize() const {
        return workspace_size;
    }

    /// Build a RearrangeInfo from the destination/source descriptors.
    /// Fails if dtypes differ, ranks differ, or shapes differ.
    static utils::Result<RearrangeInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc) {
        auto dtype = y_desc->dtype();
        auto ndim = y_desc->ndim();
        CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
        CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE);

        auto y_shape = y_desc->shape();
        auto y_strides = y_desc->strides();
        auto x_shape = x_desc->shape();
        auto x_strides = x_desc->strides();
        CHECK_SAME_SHAPE(x_shape, y_shape);

        // One size_t per dim for the shape, plus two ptrdiff_t arrays
        // (src and dst strides) of ndim entries each.
        auto workspace_size_ = sizeof(size_t) * ndim
                             + sizeof(ptrdiff_t) * ndim * 2;

        return utils::Result<RearrangeInfo>(RearrangeInfo{
            y_shape,
            x_strides,
            y_strides,
            dtype,
            workspace_size_,
        });
    }
};
/// Operator descriptor for the Kunlun rearrange implementation.
/// Owns an Opaque holding device-side resources and the validated
/// RearrangeInfo metadata; constructed only through create().
class Descriptor final : public InfiniopDescriptor {
    struct Opaque; // defined in rearrange_kunlun.xpu (holds handle internals + workspace)
    Opaque *_opaque;
    RearrangeInfo _info;

    // Private: instances are created via the static create() factory.
    Descriptor(
        Opaque *opaque,
        infiniDevice_t device_type,
        int device_id,
        RearrangeInfo info)
        : InfiniopDescriptor{device_type, device_id},
          _opaque(opaque),
          _info(info) {}

public:
    ~Descriptor();

    /// Validate the descriptors, allocate device workspace, and build a
    /// Descriptor into *desc_ptr. Returns a non-success status on mismatch
    /// or allocation failure.
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc);

    /// Launch the rearrange: copy x into y according to the stored
    /// shape/strides, asynchronously on the given stream.
    infiniStatus_t calculate(
        void *y,
        const void *x,
        void *stream) const;
};
}
// namespace op::rearrange::kunlun
#endif // __REARRANGE_KUNLUN_H__
src/infiniop/ops/rearrange/kunlun/rearrange_kunlun.xpu
0 → 100644
View file @
8c798185
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h"
#include "rearrange_kunlun.h"
#include <memory>
namespace op::rearrange::kunlun {
// Device-side state owned by a Descriptor: shared handle internals and the
// workspace buffer that stages shape/stride metadata on the device.
struct Descriptor::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;
    void *workspace;
    // Release the device buffer when the descriptor is destroyed.
    ~Opaque() {
        if (workspace) {
            xpu_free(workspace);
        }
    }
};
// Deleting _opaque also frees the device workspace (see Opaque::~Opaque).
Descriptor::~Descriptor() {
    delete _opaque;
}
/// Validate the tensor pair, allocate the device metadata workspace, and
/// construct the descriptor into *desc_ptr.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    // Gather and validate shape/stride/dtype metadata.
    auto info_result = RearrangeInfo::create(y_desc, x_desc);
    CHECK_RESULT(info_result);
    auto info = info_result.take();

    // Reserve device memory (L3) for shape + src/dst stride arrays.
    size_t ws_bytes = info.workspaceSize();
    void *ws = nullptr;
    CHECK_KUNLUN(xpu_malloc(&ws, ws_bytes, XPU_MEM_L3));

    auto *kunlun_handle = reinterpret_cast<device::kunlun::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{kunlun_handle->internal(), ws},
        handle->device,
        handle->device_id,
        std::move(info));
    return INFINI_STATUS_SUCCESS;
}
// Dispatch rearrangeKernel for the given dtype. The workspace holds the
// already-uploaded metadata laid out as [shape | src_strides | dst_strides].
// NOTE(review): the launch config <<<12, 64, stream>>> hard-codes the
// cluster/core counts — confirm these match the p800 target rather than
// querying the handle internals.
template <unsigned int BUFF_SIZE>
infiniStatus_t launchKernel(
    void *y,
    const void *x,
    void *workspace,
    size_t ndim,
    size_t total_size,
    infiniDtype_t dtype,
    kunlunStream_t stream) {
    // Recompute the sub-array pointers from the workspace base.
    __global_ptr__ size_t *d_shape = reinterpret_cast<__global_ptr__ size_t *>(workspace);
    __global_ptr__ ptrdiff_t *d_src_strides = reinterpret_cast<__global_ptr__ ptrdiff_t *>(d_shape + ndim);
    __global_ptr__ ptrdiff_t *d_dst_strides = reinterpret_cast<__global_ptr__ ptrdiff_t *>(d_src_strides + ndim);
// One launch per supported element type; metadata passed as opaque pointers.
#define LAUNCH_KERNEL(Tdata) \
    rearrangeKernel<BUFF_SIZE, Tdata> \
        <<<12, 64, stream>>>( \
            reinterpret_cast<__global_ptr__ Tdata *>(y), \
            reinterpret_cast<__global_ptr__ const Tdata *>(x), \
            reinterpret_cast<__global_ptr__ void *>(d_shape), \
            reinterpret_cast<__global_ptr__ void *>(d_src_strides), \
            reinterpret_cast<__global_ptr__ void *>(d_dst_strides), \
            static_cast<uint32_t>(ndim), \
            static_cast<uint32_t>(total_size));
    switch (dtype) {
    case INFINI_DTYPE_F32:
        LAUNCH_KERNEL(float);
        break;
    case INFINI_DTYPE_BF16:
        LAUNCH_KERNEL(bfloat16_t);
        break;
    case INFINI_DTYPE_F16:
        LAUNCH_KERNEL(half);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
}
/// Upload the shape/stride metadata into the pre-allocated workspace and
/// launch the rearrange kernel on the given stream.
infiniStatus_t Descriptor::calculate(
    void *y,
    const void *x,
    void *stream) const {
    const size_t ndim = _info.ndim();
    const size_t total_size = _info.nelements();
    const infiniDtype_t dtype = _info.dtype;

    // Workspace layout: [shape | src_strides | dst_strides], contiguous.
    void *workspace = _opaque->workspace;
    __global_ptr__ size_t *shape_dev = reinterpret_cast<__global_ptr__ size_t *>(workspace);
    __global_ptr__ ptrdiff_t *src_strides_dev = reinterpret_cast<__global_ptr__ ptrdiff_t *>(shape_dev + ndim);
    __global_ptr__ ptrdiff_t *dst_strides_dev = reinterpret_cast<__global_ptr__ ptrdiff_t *>(src_strides_dev + ndim);

    // Stage the metadata on-device; the copies and the kernel launch use the
    // same stream, which presumably orders them — confirm against the XPU
    // runtime's stream semantics.
    CHECK_KUNLUN(xpu_memcpy_async(shape_dev, _info.shape.data(), sizeof(size_t) * ndim, XPU_HOST_TO_DEVICE, stream));
    CHECK_KUNLUN(xpu_memcpy_async(src_strides_dev, _info.src_strides.data(), sizeof(ptrdiff_t) * ndim, XPU_HOST_TO_DEVICE, stream));
    CHECK_KUNLUN(xpu_memcpy_async(dst_strides_dev, _info.dst_strides.data(), sizeof(ptrdiff_t) * ndim, XPU_HOST_TO_DEVICE, stream));

    CHECK_STATUS(launchKernel<64>(y, x, workspace,
                                  ndim, total_size, dtype,
                                  reinterpret_cast<kunlunStream_t>(stream)));
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::rearrange::kunlun
src/infiniop/ops/rearrange/operator.cc
View file @
8c798185
...
@@ -20,6 +20,9 @@
...
@@ -20,6 +20,9 @@
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
#include "moore/rearrange_moore.h"
#include "moore/rearrange_moore.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rearrange_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateRearrangeDescriptor
(
__C
infiniStatus_t
infiniopCreateRearrangeDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -57,6 +60,9 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor(
...
@@ -57,6 +60,9 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor(
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -99,6 +105,9 @@ __C infiniStatus_t infiniopRearrange(
...
@@ -99,6 +105,9 @@ __C infiniStatus_t infiniopRearrange(
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -138,6 +147,9 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor(
...
@@ -138,6 +147,9 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor(
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
moore
);
DELETE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment