Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
1d064392
Unverified
Commit
1d064392
authored
Aug 28, 2025
by
zhangyue
Committed by
GitHub
Aug 28, 2025
Browse files
Merge pull request #409 from InfiniTensor/issue/340
Issue/340 接入昆仑芯 XBLAS
parents
e221916d
b92ecc31
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
64 deletions
+131
-64
src/infiniop/devices/kunlun/kunlun_xblas.cc
src/infiniop/devices/kunlun/kunlun_xblas.cc
+30
-0
src/infiniop/devices/kunlun/kunlun_xblas.h
src/infiniop/devices/kunlun/kunlun_xblas.h
+38
-0
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
+59
-63
xmake/kunlun.lua
xmake/kunlun.lua
+4
-1
No files found.
src/infiniop/devices/kunlun/kunlun_xblas.cc
0 → 100644
View file @
1d064392
#include "kunlun_xblas.h"

namespace device::kunlun::blas {

/// Construct a Kunlun XBLAS handle for the given device and allocate the
/// shared internal state (the pooled cublas-compatible handles).
Handle::Handle(int device_id)
    : InfiniopHandle{INFINI_DEVICE_KUNLUN, device_id},
      _internal(std::make_shared<Handle::Internal>()) {}

/// Accessor for the shared internal state.
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
    return _internal;
}

/// Factory used by the dispatch layer: allocates a new Handle and hands
/// ownership to the caller through handle_ptr.
///
/// @param handle_ptr out-param receiving the newly allocated handle
/// @param device_id  Kunlun device index
/// @return INFINI_STATUS_SUCCESS (allocation failure would throw)
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(device_id);
    return INFINI_STATUS_SUCCESS;
}

/// Run `f` with a blas handle bound to `stream`.
///
/// Handles are recycled through `blas_handles`; one is created lazily when
/// the pool is empty, and it is returned to the pool on success.
/// NOTE(review): if cublasSetStream or `f` fails, the CHECK_* macros return
/// early and the handle is not pushed back — it is lost (not destroyed);
/// acceptable for a small pool but worth confirming upstream.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    // assumes Pool::pop() returns a std::optional-like value — TODO confirm
    auto handle = blas_handles.pop();
    if (!handle) {
        // BUGFIX: the original called cublasCreate(&(*handle)) while the
        // optional was still disengaged, which is undefined behavior.
        // Engage it with a default-constructed (null) handle first.
        handle.emplace();
        CHECK_CUBLAS(cublasCreate(&(*handle)));
    }
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    CHECK_STATUS(f(*handle));
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}

} // namespace device::kunlun::blas
src/infiniop/devices/kunlun/kunlun_xblas.h
0 → 100644
View file @
1d064392
#ifndef __KUNLUN_XBLAS_H__
#define __KUNLUN_XBLAS_H__

#include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"

#include <cublas_v2.h>
#include <functional> // BUGFIX: std::function used below but was never included
#include <memory>

#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)

namespace device::kunlun::blas {

/// InfiniOp handle for the Kunlun XBLAS (cublas-compatible) backend.
struct Handle : public InfiniopHandle {
    class Internal;

    /// Shared internal state (pool of blas handles).
    auto internal() const -> const std::shared_ptr<Internal> &;

    /// `explicit` added: prevents accidental int -> Handle conversion;
    /// all visible call sites use direct initialization.
    explicit Handle(int device_id);

private:
    std::shared_ptr<Internal> _internal;

public:
    /// Allocate a new Handle for `device_id` into `*handle_ptr`.
    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};

/// Pool of reusable blas handles plus the callback-style entry point that
/// binds one of them to a stream.
class Handle::Internal {
    Pool<cublasHandle_t> blas_handles;

    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    /// Borrow (or lazily create) a handle, bind it to `stream`, run `f`,
    /// and return the handle to the pool on success.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};

} // namespace device::kunlun::blas

#endif // __KUNLUN_XBLAS_H__
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
View file @
1d064392
#include "gemm_kunlun.h"
#include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_
handle
.h"
#include "../../../devices/kunlun/kunlun_
xblas
.h"
namespace
op
::
gemm
::
kunlun
{
typedef
device
::
kunlun
::
Handle
::
Internal
HandleInternal
;
typedef
device
::
kunlun
::
blas
::
Handle
::
Internal
HandleInternal
;
struct
Descriptor
::
Opaque
{
std
::
shared_ptr
<
HandleInternal
>
internal
;
...
...
@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
auto
handle
=
reinterpret_cast
<
device
::
kunlun
::
Handle
*>
(
handle_
);
auto
handle
=
reinterpret_cast
<
device
::
kunlun
::
blas
::
Handle
*>
(
handle_
);
auto
dtype
=
c_desc
->
dtype
();
if
(
dtype
!=
INFINI_DTYPE_F16
&&
dtype
!=
INFINI_DTYPE_F32
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
CHECK_DTYPE
(
dtype
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_BF16
);
auto
result
=
MatmulInfo
::
create
(
c_desc
,
a_desc
,
b_desc
,
MatrixLayout
::
ROW
_MAJOR
);
auto
result
=
MatmulInfo
::
create
(
c_desc
,
a_desc
,
b_desc
,
MatrixLayout
::
COL
_MAJOR
);
CHECK_RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
...
...
@@ -38,75 +35,74 @@ infiniStatus_t Descriptor::create(
return
INFINI_STATUS_SUCCESS
;
}
/// Batched GEMM via xdnn::fc_fusion, one kernel launch per batch element.
///
/// @param info     matmul geometry (batch, m/n/k, per-matrix strides/ld);
///                 passed by const& instead of by value to avoid a copy
/// @param internal handle-pool wrapper; const& avoids an atomic refcount bump
/// @param dtype    element dtype, used only to compute the byte stride
/// @param c        output buffer; a/b inputs; alpha/beta GEMM scalars
/// @param stream   Kunlun stream the work is enqueued on
template <class Tdata>
infiniStatus_t calculate(
    const MatmulInfo &info,
    const std::shared_ptr<HandleInternal> &internal,
    infiniDtype_t dtype,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    kunlunStream_t stream) {
    // Pointers are local copies, so swapping them here does not affect the caller.
    const void *lhs = a;
    const void *rhs = b;
    if (info.is_transed) {
        std::swap(lhs, rhs);
    }
    // Row-major data has col_stride == 1; anything else is treated as transposed.
    // (Original used the redundant `cond ? false : true` form.)
    bool transA = info.a_matrix.col_stride != 1;
    bool transB = info.b_matrix.col_stride != 1;
    auto unit = infiniSizeOf(dtype);
    CHECK_STATUS(internal->useXdnn(
        (kunlunStream_t)stream,
        [&](xdnnHandle_t handle) {
            for (size_t i = 0; i < info.batch; i++) {
                // Byte offsets: per-batch stride is in elements, hence `* unit`.
                CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                    handle,
                    (Tdata *)((char *)lhs + i * info.a_matrix.stride * unit),
                    (Tdata *)((char *)rhs + i * info.b_matrix.stride * unit),
                    (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
                    info.m,
                    info.n,
                    info.k,
                    transA,
                    transB,
                    nullptr, // x_maxptr
                    nullptr, // w_maxptr
                    nullptr, // y_maxptr
                    info.a_matrix.ld(),
                    info.b_matrix.ld(),
                    info.c_matrix.ld(),
                    alpha,
                    beta,
                    nullptr, // bias
                    xdnn::Activation_t::LINEAR,
                    nullptr)));
            }
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}
/// Batched GEMM through the XBLAS cublas-compatible API.
///
/// Reconstructed post-commit version: the diff residue interleaved the
/// deleted xdnn dispatch (`return calculate<float16>/<float>(...)`) and a
/// misspelled duplicate `worksapce_size` parameter with the added cublas
/// path; only the added side is kept here.
///
/// @param workspace / workspace_size  unused by this backend — kept for the
///        common infiniop calculate() signature
/// @param c,a,b  device buffers; alpha/beta GEMM scalars; stream opaque stream
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    cudaDataType a_type, b_type, c_type;
    cublasComputeType_t compute_type;
    // Map the descriptor dtype to cublas storage/compute types.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        a_type = b_type = c_type = CUDA_R_16F;
        compute_type = CUBLAS_COMPUTE_32F; // accumulate in F32 for accuracy
        break;
    case INFINI_DTYPE_BF16:
        a_type = b_type = c_type = CUDA_R_16BF;
        compute_type = CUBLAS_COMPUTE_32F;
        break;
    case INFINI_DTYPE_F32:
        a_type = b_type = c_type = CUDA_R_32F;
        compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // a/b are local pointer copies; swap mirrors the pre-transposed layout.
    if (_info.is_transed) {
        std::swap(a, b);
    }
    // Column-major data has row_stride == 1 -> no-op; otherwise transpose.
    auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
    auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
    CHECK_STATUS(_opaque->internal->useCublas(
        (cudaStream_t)stream,
        [&](cublasHandle_t handle) {
            CHECK_CUBLAS(cublasGemmStridedBatchedEx(
                handle,
                op_a,
                op_b,
                static_cast<int>(_info.m),
                static_cast<int>(_info.n),
                static_cast<int>(_info.k),
                &alpha,
                a, a_type, static_cast<int>(_info.a_matrix.ld()), _info.a_matrix.stride,
                b, b_type, static_cast<int>(_info.b_matrix.ld()), _info.b_matrix.stride,
                &beta,
                c, c_type, static_cast<int>(_info.c_matrix.ld()), _info.c_matrix.stride,
                static_cast<int>(_info.batch),
                compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::gemm::kunlun
xmake/kunlun.lua
View file @
1d064392
...
...
@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
-- Toolkit layout under KUNLUN_HOME: xre (runtime), xtdk (toolchain),
-- xhpc/xdnn (DNN kernels), xhpc/xblas (BLAS).
-- Reconstructed post-commit state: the diff residue contained both the old
-- add_links("xpurt", "xpuapi") and the new XBLAS-aware variant; only the
-- new one is kept.
local XRE_DIR = path.join(KUNLUN_HOME, "xre")
local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
local XBLAS_DIR = path.join(KUNLUN_HOME, "xhpc", "xblas")

-- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
add_includedirs(path.join(XBLAS_DIR, "include"), {public = true})

-- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so"))
add_linkdirs(path.join(XBLAS_DIR, "so"))
add_links("xpurt", "xpuapi", "xpu_blas")
rule
(
"xpu"
)
set_extensions
(
".xpu"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment