Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8583cae7
Commit
8583cae7
authored
Feb 19, 2025
by
zhangyue
Browse files
issue/25: kunlun foundation and matmul
parent
45175dbf
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
293 additions
and
0 deletions
+293
-0
src/infiniop/devices/handle.cc
src/infiniop/devices/handle.cc
+13
-0
src/infiniop/devices/kunlun/common_kunlun.h
src/infiniop/devices/kunlun/common_kunlun.h
+45
-0
src/infiniop/devices/kunlun/kunlun_handle.cc
src/infiniop/devices/kunlun/kunlun_handle.cc
+24
-0
src/infiniop/devices/kunlun/kunlun_handle.h
src/infiniop/devices/kunlun/kunlun_handle.h
+13
-0
src/infiniop/ops/matmul/kunlun/matmul_xdnn.cc
src/infiniop/ops/matmul/kunlun/matmul_xdnn.cc
+106
-0
src/infiniop/ops/matmul/kunlun/matmul_xdnn.h
src/infiniop/ops/matmul/kunlun/matmul_xdnn.h
+17
-0
src/infiniop/ops/matmul/kunlun/matmul_xdnn_api.h
src/infiniop/ops/matmul/kunlun/matmul_xdnn_api.h
+31
-0
src/infiniop/ops/matmul/operator.cc
src/infiniop/ops/matmul/operator.cc
+3
-0
test/infiniop/libinfiniop/utils.py
test/infiniop/libinfiniop/utils.py
+8
-0
xmake.lua
xmake.lua
+14
-0
xmake/kunlun.lua
xmake/kunlun.lua
+19
-0
No files found.
src/infiniop/devices/handle.cc
View file @
8583cae7
...
@@ -11,6 +11,9 @@
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
#include "ascend/ascend_handle.h"
#include "ascend/ascend_handle.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "./kunlun/kunlun_handle.h"
#endif
__C
infiniStatus_t
infiniopCreateHandle
(
infiniopHandle_t
*
handle_ptr
,
__C
infiniStatus_t
infiniopCreateHandle
(
infiniopHandle_t
*
handle_ptr
,
infiniDevice_t
device
)
{
infiniDevice_t
device
)
{
...
@@ -37,6 +40,11 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
...
@@ -37,6 +40,11 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
case
INFINI_DEVICE_ASCEND
:
{
case
INFINI_DEVICE_ASCEND
:
{
return
createAscendHandle
((
infiniopAscendHandle_t
*
)
handle_ptr
);
return
createAscendHandle
((
infiniopAscendHandle_t
*
)
handle_ptr
);
}
}
#endif
#ifdef ENABLE_KUNLUN_API
case
INFINI_DEVICE_KUNLUN
:
{
return
createKunlunHandle
((
infiniopKunlunHandle_t
*
)
handle_ptr
);
}
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -62,6 +70,11 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
...
@@ -62,6 +70,11 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
case
INFINI_DEVICE_ASCEND
:
{
case
INFINI_DEVICE_ASCEND
:
{
return
destroyAscendHandle
((
infiniopAscendHandle_t
)
handle
);
return
destroyAscendHandle
((
infiniopAscendHandle_t
)
handle
);
}
}
#endif
#ifdef ENABLE_KUNLUN_API
case
INFINI_DEVICE_KUNLUN
:
{
return
destroyKunlunHandle
((
infiniopKunlunHandle_t
)
handle
);
}
#endif
#endif
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/devices/kunlun/common_kunlun.h
0 → 100644
View file @
8583cae7
#ifndef __INFINIOP_COMMON_KUNLUN_H__
#define __INFINIOP_COMMON_KUNLUN_H__

#include "../pool.h"
#include "infinicore.h"
#include "kunlun_handle.h"
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
#include "xpu/xdnn.h"

#include <memory>
#include <utility>

namespace xdnn = baidu::xpu::api;
typedef xdnn::Context *xdnnHandle_t;

// Evaluate a Kunlun runtime/XDNN call and return INFINIOP_STATUS_INTERNAL_ERROR
// from the enclosing function on failure, logging file/line and the runtime's
// error string.
// NOTE(review): deliberately a bare block, not do{}while(0) — existing call
// sites invoke CHECK_KUNLUN(...) without a trailing semicolon, which a
// do/while form would break.
#define CHECK_KUNLUN(call)                                            \
    {                                                                 \
        auto err = call;                                              \
        if (XPU_SUCCESS != err) {                                     \
            fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \
                    __LINE__, xpu_strerror(err));                     \
            return INFINIOP_STATUS_INTERNAL_ERROR;                    \
        }                                                             \
    }

// Per-device handle for the Kunlun backend: the device tag/ordinal plus a
// pool of reusable xdnn contexts shared with every operator descriptor.
struct InfiniopKunlunHandle {
    infiniDevice_t device; // always INFINI_DEVICE_KUNLUN
    int device_id;
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};

// Borrow an xdnn context from the pool, bind it to `stream`, run `f(ctx)`,
// and return the context to the pool. `f` receives an xdnnHandle_t and must
// return an infiniopStatus_t.
template <typename T>
infiniopStatus_t use_xdnn(std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool,
                          XPUStream stream, T const &f) {
    // Pool::pop() returns an optional-like value that is empty when no cached
    // context is available (inferred from its boolean test here — confirm
    // against Pool's declaration in ../pool.h).
    auto popped = xdnn_handle_pool->pop();
    // Fix: the original wrote `*handle = xdnn::create_context()` into an
    // EMPTY optional, which dereferences a disengaged optional (undefined
    // behavior). Create the context into a plain local instead.
    xdnnHandle_t ctx = popped ? *popped : xdnn::create_context();
    ctx->set_stream(stream);
    auto ret = f(ctx);
    xdnn_handle_pool->push(std::move(ctx));
    return ret;
}

#endif
src/infiniop/devices/kunlun/kunlun_handle.cc
0 → 100644
View file @
8583cae7
#include "common_kunlun.h"
// Create a Kunlun device handle bound to the current XPU device.
// The xdnn context pool is seeded with one context so the first operator
// launch does not pay the context-creation cost.
infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) {
    int device_id = 0;
    CHECK_KUNLUN(xpu_current_device(&device_id))
    auto pool = std::make_shared<Pool<xdnnHandle_t>>();
    pool->push(xdnn::create_context());
    *handle_ptr = new InfiniopKunlunHandle{
        INFINI_DEVICE_KUNLUN,
        device_id,
        std::move(pool),
    };
    return INFINIOP_STATUS_SUCCESS;
}
// Destroy a Kunlun device handle.
// Fix: the original dropped the pool without destroying the xdnn contexts it
// held, leaking every context created by createKunlunHandle/use_xdnn. Drain
// the pool and destroy each context before freeing the handle.
infiniopStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle) {
    // Pool::pop() is empty-convertible-to-false when exhausted (same
    // optional-like contract relied on by use_xdnn).
    while (auto ctx = handle->xdnn_handle_pool->pop()) {
        xdnn::destroy_context(*ctx);
    }
    handle->xdnn_handle_pool = nullptr;
    delete handle;
    return INFINIOP_STATUS_SUCCESS;
}
src/infiniop/devices/kunlun/kunlun_handle.h
0 → 100644
View file @
8583cae7
#ifndef __INFINIOP_KUNLUN_HANDLE_H__
#define __INFINIOP_KUNLUN_HANDLE_H__

#include "infiniop/handle.h"

// Opaque handle for the Baidu Kunlun XPU backend; the definition lives in
// common_kunlun.h so this header stays free of XPU toolkit includes.
struct InfiniopKunlunHandle;
typedef struct InfiniopKunlunHandle *infiniopKunlunHandle_t;

// Allocate a handle bound to the current XPU device and store it in
// *handle_ptr.
infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr);

// Release a handle previously obtained from createKunlunHandle.
infiniopStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle);

#endif // __INFINIOP_KUNLUN_HANDLE_H__
src/infiniop/ops/matmul/kunlun/matmul_xdnn.cc
0 → 100644
View file @
8583cae7
#include "matmul_xdnn.h"
// Shared F16/F32 implementation: batched C = alpha * op(A) * op(B) + beta * C
// via xdnn::fc_fusion, one call per batch slice on the given stream.
template <typename T>
infiniopStatus_t matmulKunlunCommon(infiniopMatmulKunlunDescriptor_t desc,
                                    void *c, float beta,
                                    void const *a, void const *b,
                                    float alpha, void *stream) {
    // Fix: bind the matmul info by const reference; the original copied the
    // whole MatmulInfo struct on every call.
    const auto &info = desc->info;
    // When the layout pass normalized the product by transposing it, the
    // operand roles are swapped.
    if (info.is_transed) {
        std::swap(a, b);
    }
    // A unit column stride means the matrix is laid out row-major and needs
    // no transpose flag.
    const bool transA = info.a_matrix.col_stride != 1;
    const bool transB = info.b_matrix.col_stride != 1;
    // Hoisted out of the batch loop — invariant per call.
    const size_t elem_size = infiniSizeof(desc->dtype);

    return use_xdnn(
        desc->xdnn_handle_pool, (XPUStream)stream,
        [&](xdnnHandle_t handle) -> infiniopStatus_t {
            for (size_t i = 0; i < info.batch; ++i) {
                // Fix: keep const on the input operands instead of casting it
                // away with C-style casts; fc_fusion takes const pointers for
                // its x/w inputs.
                auto a_i = reinterpret_cast<const T *>(
                    static_cast<const char *>(a) + i * info.a_matrix.stride * elem_size);
                auto b_i = reinterpret_cast<const T *>(
                    static_cast<const char *>(b) + i * info.b_matrix.stride * elem_size);
                auto c_i = reinterpret_cast<T *>(
                    static_cast<char *>(c) + i * info.c_matrix.stride * elem_size);
                CHECK_KUNLUN((xdnn::fc_fusion<T, T, T, int16_t>(
                    handle, a_i, b_i, c_i,
                    info.m, info.n, info.k,
                    transA, transB,
                    nullptr, nullptr, nullptr, // x_max / w_max / y_max: unquantized
                    info.a_matrix.ld(), info.b_matrix.ld(), info.c_matrix.ld(),
                    alpha, beta,
                    nullptr,                    // no bias
                    xdnn::Activation_t::LINEAR, // no activation
                    nullptr)));                 // no scale
            }
            return INFINIOP_STATUS_SUCCESS;
        });
}
// Build a Kunlun matmul descriptor for C = A * B.
// Validates the element type (only F16/F32 are supported by the xdnn path),
// precomputes shape/stride info, and shares the device handle's context pool.
infiniopStatus_t kunlunCreateMatmulDescriptor(infiniopKunlunHandle_t handle,
                                              infiniopMatmulKunlunDescriptor_t *desc_ptr,
                                              infiniopTensorDescriptor_t c_desc,
                                              infiniopTensorDescriptor_t a_desc,
                                              infiniopTensorDescriptor_t b_desc) {
    const infiniDtype_t dtype = c_desc->dtype;
    const bool supported = (dtype == INFINI_DTYPE_F16) || (dtype == INFINI_DTYPE_F32);
    if (!supported) {
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }

    infiniopStatus_t status;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, false);
    if (status != INFINIOP_STATUS_SUCCESS) {
        return status;
    }

    *desc_ptr = new InfiniopMatmulKunlunDescriptor{
        INFINI_DEVICE_KUNLUN,
        dtype,
        handle->device_id,
        info,
        handle->xdnn_handle_pool};
    return INFINIOP_STATUS_SUCCESS;
}
// The xdnn matmul path needs no scratch memory, so the required workspace is
// always zero bytes.
infiniopStatus_t kunlunGetMatmulWorkspaceSize(infiniopMatmulKunlunDescriptor_t desc,
                                              size_t *size) {
    (void)desc; // descriptor carries no workspace-relevant state
    *size = 0;
    return INFINIOP_STATUS_SUCCESS;
}
// Execute the matmul described by `desc` on `stream`, dispatching to the
// typed common implementation. The workspace is unused (size 0 — see
// kunlunGetMatmulWorkspaceSize).
infiniopStatus_t kunlunMatmul(infiniopMatmulKunlunDescriptor_t desc,
                              void *workspace, size_t workspace_size,
                              void *c, void const *a, void const *b,
                              float alpha, float beta, void *stream) {
    (void)workspace;
    (void)workspace_size;
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
        return matmulKunlunCommon<float16>(desc, c, beta, a, b, alpha, stream);
    case INFINI_DTYPE_F32:
        return matmulKunlunCommon<float>(desc, c, beta, a, b, alpha, stream);
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
// Free a matmul descriptor. The context pool is shared with the device
// handle, so dropping this reference does not destroy the contexts.
infiniopStatus_t kunlunDestroyMatmulDescriptor(infiniopMatmulKunlunDescriptor_t desc) {
    desc->xdnn_handle_pool.reset();
    delete desc;
    return INFINIOP_STATUS_SUCCESS;
}
src/infiniop/ops/matmul/kunlun/matmul_xdnn.h
0 → 100644
View file @
8583cae7
#ifndef __MATMUL_XDNN_H__
#define __MATMUL_XDNN_H__

#include "../../../devices/kunlun/common_kunlun.h"
#include "../../utils.h"
#include "../blas.h"
#include "matmul_xdnn_api.h"

// Concrete matmul descriptor for the Kunlun XPU backend.
struct InfiniopMatmulKunlunDescriptor {
    infiniDevice_t device; // always INFINI_DEVICE_KUNLUN
    infiniDtype_t dtype;   // element type of C (F16 or F32)
    int device_id;         // XPU device the descriptor was created on
    MatmulInfo info;       // precomputed shape/stride info for C = A * B
    // Shared with the device handle so executions reuse its xdnn contexts.
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};

#endif // __MATMUL_XDNN_H__
src/infiniop/ops/matmul/kunlun/matmul_xdnn_api.h
0 → 100644
View file @
8583cae7
#ifndef __MATMUL_XDNN_API_H__
#define __MATMUL_XDNN_API_H__

#include "../../../devices/kunlun/kunlun_handle.h"
#include "infiniop/operator.h"

// Opaque matmul descriptor for the Kunlun backend; defined in matmul_xdnn.h.
struct InfiniopMatmulKunlunDescriptor;
typedef struct InfiniopMatmulKunlunDescriptor *infiniopMatmulKunlunDescriptor_t;

// Create a matmul descriptor for C = A * B from the three tensor
// descriptors. Only F16 and F32 element types are accepted.
infiniopStatus_t kunlunCreateMatmulDescriptor(infiniopKunlunHandle_t handle,
                                              infiniopMatmulKunlunDescriptor_t *desc_ptr,
                                              infiniopTensorDescriptor_t c_desc,
                                              infiniopTensorDescriptor_t a_desc,
                                              infiniopTensorDescriptor_t b_desc);

// Query the scratch-memory requirement for the descriptor (currently 0).
infiniopStatus_t kunlunGetMatmulWorkspaceSize(infiniopMatmulKunlunDescriptor_t desc,
                                              size_t *size);

// Run C = alpha * A * B + beta * C on the given stream.
infiniopStatus_t kunlunMatmul(infiniopMatmulKunlunDescriptor_t desc,
                              void *workspace, size_t workspace_size,
                              void *c, void const *a, void const *b,
                              float alpha, float beta, void *stream);

// Free the descriptor created by kunlunCreateMatmulDescriptor.
infiniopStatus_t kunlunDestroyMatmulDescriptor(infiniopMatmulKunlunDescriptor_t desc);

#endif // __MATMUL_XDNN_API_H__
src/infiniop/ops/matmul/operator.cc
View file @
8583cae7
...
@@ -12,6 +12,9 @@
...
@@ -12,6 +12,9 @@
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
#include "ascend/matmul_ascend.h"
#include "ascend/matmul_ascend.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/matmul_xdnn_api.h"
#endif
__C
infiniStatus_t
infiniopCreateMatmulDescriptor
(
__C
infiniStatus_t
infiniopCreateMatmulDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
...
test/infiniop/libinfiniop/utils.py
View file @
8583cae7
...
@@ -166,6 +166,11 @@ def get_args():
...
@@ -166,6 +166,11 @@ def get_args():
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Run ASCEND NPU test"
,
help
=
"Run ASCEND NPU test"
,
)
)
parser
.
add_argument
(
"--kunlun"
,
action
=
"store_true"
,
help
=
"Run KUNLUN XPU test"
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -428,6 +433,9 @@ def get_test_devices(args):
...
@@ -428,6 +433,9 @@ def get_test_devices(args):
torch
.
npu
.
set_device
(
0
)
# Ascend NPU needs explicit device initialization
torch
.
npu
.
set_device
(
0
)
# Ascend NPU needs explicit device initialization
devices_to_test
.
append
(
InfiniDeviceEnum
.
ASCEND
)
devices_to_test
.
append
(
InfiniDeviceEnum
.
ASCEND
)
if
args
.
kunlun
:
import
torch_xmlir
devices_to_test
.
append
(
InfiniDeviceEnum
.
KUNLUN
)
if
not
devices_to_test
:
if
not
devices_to_test
:
devices_to_test
=
[
InfiniDeviceEnum
.
CPU
]
devices_to_test
=
[
InfiniDeviceEnum
.
CPU
]
...
...
xmake.lua
View file @
8583cae7
...
@@ -100,6 +100,17 @@ if has_config("sugon-dcu") then
...
@@ -100,6 +100,17 @@ if has_config("sugon-dcu") then
add_defines
(
"ENABLE_SUGON_CUDA_API"
)
add_defines
(
"ENABLE_SUGON_CUDA_API"
)
end
end
-- 昆仑芯
option
(
"kunlun-xpu"
)
set_default
(
false
)
set_showmenu
(
true
)
set_description
(
"Enable or disable Kunlun XPU kernel"
)
option_end
()
if
has_config
(
"kunlun-xpu"
)
then
add_defines
(
"ENABLE_KUNLUN_API"
)
includes
(
"xmake/kunlun.lua"
)
end
target
(
"infiniop"
)
target
(
"infiniop"
)
set_kind
(
"shared"
)
set_kind
(
"shared"
)
...
@@ -134,6 +145,9 @@ target("infiniop")
...
@@ -134,6 +145,9 @@ target("infiniop")
if
has_config
(
"metax-gpu"
)
then
if
has_config
(
"metax-gpu"
)
then
add_deps
(
"metax-gpu"
)
add_deps
(
"metax-gpu"
)
end
end
if
has_config
(
"kunlun-xpu"
)
then
add_deps
(
"infiniop-kunlun"
)
end
set_languages
(
"cxx17"
)
set_languages
(
"cxx17"
)
add_files
(
"src/infiniop/devices/handle.cc"
)
add_files
(
"src/infiniop/devices/handle.cc"
)
add_files
(
"src/infiniop/ops/*/operator.cc"
)
add_files
(
"src/infiniop/ops/*/operator.cc"
)
...
...
xmake/kunlun.lua
0 → 100644
View file @
8583cae7
add_defines("ENABLE_KUNLUN_API")

local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
-- Fix: fail fast with a clear message instead of letting path.join choke on a
-- nil argument when the toolkit location is not configured.
if not KUNLUN_HOME then
    raise("KUNLUN_HOME is not set; point it at the Kunlun XPU toolkit root before enabling kunlun-xpu")
end

-- Toolkit headers and libraries.
add_includedirs(path.join(KUNLUN_HOME, "include"), {public = true})
add_linkdirs(path.join(KUNLUN_HOME, "lib64"))
add_links("xpurt")
add_links("xpuapi")

target("infiniop-kunlun")
    set_kind("static")
    set_languages("cxx17")
    -- Static helper library: nothing to install on its own.
    on_install(function (target) end)
    -- Backend sources: device plumbing plus every op's kunlun kernel.
    add_files("$(projectdir)/src/infiniop/devices/kunlun/*.cc",
              "$(projectdir)/src/infiniop/ops/*/kunlun/*.cc")
    -- NOTE(review): "-lstdc++" is a linker flag passed via add_cxflags; kept
    -- as-is to avoid changing build behavior, but it likely belongs in
    -- add_ldflags/add_syslinks.
    add_cxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment