Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a9acf208
Commit
a9acf208
authored
Aug 05, 2025
by
xgqdut2016
Committed by
zhangyue
Aug 26, 2025
Browse files
issue/340: kunlun cublas gemm
parent
cb06c721
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
111 additions
and
26 deletions
+111
-26
src/infiniop/devices/kunlun/kunlun_common.h
src/infiniop/devices/kunlun/kunlun_common.h
+1
-0
src/infiniop/devices/kunlun/kunlun_handle.cc
src/infiniop/devices/kunlun/kunlun_handle.cc
+11
-0
src/infiniop/devices/kunlun/kunlun_handle.h
src/infiniop/devices/kunlun/kunlun_handle.h
+2
-0
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
+97
-26
No files found.
src/infiniop/devices/kunlun/kunlun_common.h
View file @
a9acf208
...
@@ -13,5 +13,6 @@ typedef XPUEvent kunlunEvent_t;
...
@@ -13,5 +13,6 @@ typedef XPUEvent kunlunEvent_t;
// Handle type for Baidu Kunlun's XDNN library: a raw pointer to an
// xdnn::Context owned/recycled by the handle pool, not by this alias.
using xdnnHandle_t = xdnn::Context *;

// Status-checking wrappers: expand to CHECK_INTERNAL comparing the API's
// return value against the library's success code.
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#endif
#endif
src/infiniop/devices/kunlun/kunlun_handle.cc
View file @
a9acf208
...
@@ -12,6 +12,17 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
...
@@ -12,6 +12,17 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
/// @brief Allocate a new Kunlun device handle.
/// @param handle_ptr Out-parameter receiving the newly allocated handle;
///                   ownership transfers to the caller.
/// @param device_id  Index of the Kunlun device the handle is bound to.
/// @return INFINI_STATUS_SUCCESS on success.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(device_id);
    // Fix: the visible original fell off the end of a value-returning
    // function (undefined behavior); report success explicitly.
    return INFINI_STATUS_SUCCESS;
}
/// @brief Run `f` with a cuBLAS handle drawn from the per-device pool,
///        bound to `stream` for the duration of the call.
/// @param stream Stream the pooled cuBLAS handle is attached to before use.
/// @param f      Callback receiving the prepared cublasHandle_t; its status
///               is propagated via CHECK_STATUS.
/// @return INFINI_STATUS_SUCCESS, or the first failing status from cuBLAS
///         or from `f`.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    // Try to reuse a pooled handle; creating a cublasHandle_t is expensive.
    auto handle = blas_handles.pop();
    if (!handle) {
        // NOTE(review): this dereferences `handle` right after `pop()`
        // reported it empty. If Pool::pop returns std::optional, `*handle`
        // on a disengaged optional is UB — confirm Pool's contract, or
        // engage the slot before taking its address.
        CHECK_CUBLAS(cublasCreate(&(*handle)));
    }
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    // NOTE(review): if `f` (or cublasSetStream) fails, CHECK_* returns early
    // and the handle is never pushed back — it leaks from the pool on the
    // error path. Verify this is acceptable for error handling here.
    CHECK_STATUS(f(*handle));
    // Return the handle to the pool for reuse on success.
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
...
...
src/infiniop/devices/kunlun/kunlun_handle.h
View file @
a9acf208
...
@@ -23,11 +23,13 @@ public:
...
@@ -23,11 +23,13 @@ public:
// Internal state of the Kunlun device handle: pools of library handles that
// are checked out, bound to a stream, used via a callback, and returned.
class Handle::Internal {
    // Pool of XDNN contexts reused across useXdnn calls.
    Pool<xdnnHandle_t> dnn_handles;
    // Pool of cuBLAS handles reused across useCublas calls.
    Pool<cublasHandle_t> blas_handles;

    // Callback signature: receives a ready-to-use library handle and
    // returns an infiniStatus_t that the use* wrapper propagates.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    // Run `f` with a pooled XDNN handle bound to `stream`.
    infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
    // Run `f` with a pooled cuBLAS handle bound to `stream`.
    // NOTE(review): takes a cudaStream_t while useXdnn takes a
    // kunlunStream_t — callers cast between them; confirm the two stream
    // types are layout-compatible on this platform.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
}
// namespace device::kunlun
}
// namespace device::kunlun
...
...
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
View file @
a9acf208
...
@@ -38,6 +38,58 @@ infiniStatus_t Descriptor::create(
...
@@ -38,6 +38,58 @@ infiniStatus_t Descriptor::create(
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
// template <class Tdata>
// infiniStatus_t calculate(
// MatmulInfo info,
// std::shared_ptr<HandleInternal> internal,
// infiniDtype_t dtype,
// void *c,
// float beta,
// const void *a,
// const void *b,
// float alpha,
// kunlunStream_t stream) {
// if (info.is_transed) {
// std::swap(a, b);
// }
// auto transA = info.a_matrix.col_stride == 1 ? false : true;
// auto transB = info.b_matrix.col_stride == 1 ? false : true;
// auto unit = infiniSizeOf(dtype);
// CHECK_STATUS(internal->useXdnn(
// (kunlunStream_t)stream,
// [&](xdnnHandle_t handle) {
// for (size_t i = 0; i < info.batch; i++) {
// CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
// handle,
// (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
// (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
// (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
// info.m,
// info.n,
// info.k,
// transA,
// transB,
// nullptr,
// nullptr,
// nullptr,
// info.a_matrix.ld(),
// info.b_matrix.ld(),
// info.c_matrix.ld(),
// alpha,
// beta,
// nullptr,
// xdnn::Activation_t::LINEAR,
// nullptr)));
// }
// return INFINI_STATUS_SUCCESS;
// }));
// return INFINI_STATUS_SUCCESS;
// }
template
<
class
Tdata
>
template
<
class
Tdata
>
infiniStatus_t
calculate
(
infiniStatus_t
calculate
(
MatmulInfo
info
,
MatmulInfo
info
,
...
@@ -54,37 +106,56 @@ infiniStatus_t calculate(
...
@@ -54,37 +106,56 @@ infiniStatus_t calculate(
std
::
swap
(
a
,
b
);
std
::
swap
(
a
,
b
);
}
}
auto
transA
=
info
.
a_matrix
.
col_stride
==
1
?
false
:
true
;
auto
transA
=
info
.
a_matrix
.
col_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
auto
transB
=
info
.
b_matrix
.
col_stride
==
1
?
false
:
true
;
auto
transB
=
info
.
b_matrix
.
col_stride
==
1
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cudaDataType_t
a_type
,
b_type
,
c_type
;
cublasComputeType_t
compute_type
;
switch
(
dtype
)
{
case
INFINI_DTYPE_F16
:
a_type
=
b_type
=
c_type
=
CUDA_R_16F
;
compute_type
=
CUBLAS_COMPUTE_32F
;
break
;
case
INFINI_DTYPE_BF16
:
a_type
=
b_type
=
c_type
=
CUDA_R_16BF
;
compute_type
=
CUBLAS_COMPUTE_32F
;
break
;
case
INFINI_DTYPE_F32
:
a_type
=
b_type
=
c_type
=
CUDA_R_32F
;
compute_type
=
CUBLAS_COMPUTE_32F_FAST_TF32
;
break
;
auto
unit
=
infiniSizeOf
(
dtype
);
default:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
CHECK_STATUS
(
internal
->
use
Xdnn
(
CHECK_STATUS
(
internal
->
use
Cublas
(
(
kunlun
Stream_t
)
stream
,
(
cuda
Stream_t
)
stream
,
[
&
](
xdnn
Handle_t
handle
)
{
[
&
](
cublas
Handle_t
handle
)
{
for
(
size_t
i
=
0
;
i
<
info
.
batch
;
i
++
)
{
CHECK_CUBLAS
(
CHECK_KUNLUN
((
xdnn
::
fc_fusion
<
Tdata
,
Tdata
,
Tdata
,
int16_t
>
(
cublasGemmStridedBatchedEx
(
handle
,
handle
,
(
Tdata
*
)((
char
*
)
a
+
i
*
info
.
a_matrix
.
stride
*
unit
),
(
Tdata
*
)((
char
*
)
b
+
i
*
info
.
b_matrix
.
stride
*
unit
),
(
Tdata
*
)((
char
*
)
c
+
i
*
info
.
c_matrix
.
stride
*
unit
),
info
.
m
,
info
.
n
,
info
.
k
,
transA
,
transA
,
transB
,
transB
,
nullptr
,
static_cast
<
int
>
(
info
.
m
),
nullptr
,
static_cast
<
int
>
(
info
.
n
),
nullptr
,
static_cast
<
int
>
(
info
.
k
),
info
.
a_matrix
.
ld
(),
&
alpha
,
info
.
b_matrix
.
ld
(),
a
,
info
.
c_matrix
.
ld
(),
a_type
,
alpha
,
static_cast
<
int
>
(
info
.
a_matrix
.
ld
()),
beta
,
info
.
a_matrix
.
stride
,
nullptr
,
b
,
xdnn
::
Activation_t
::
LINEAR
,
b_type
,
nullptr
)));
static_cast
<
int
>
(
info
.
b_matrix
.
ld
()),
}
info
.
b_matrix
.
stride
,
&
beta
,
c
,
c_type
,
static_cast
<
int
>
(
info
.
c_matrix
.
ld
()),
info
.
c_matrix
.
stride
,
static_cast
<
int
>
(
info
.
batch
),
compute_type
,
CUBLAS_GEMM_DEFAULT
));
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}));
}));
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment