Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5450c707
Commit
5450c707
authored
Mar 12, 2025
by
qinyiqun
Browse files
issue/31 添加摩尔线程matmul算子 (Add Moore Threads matmul operator)
parent
0b9a1764
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
187 additions
and
4 deletions
+187
-4
src/infiniop/devices/musa/common_musa.h
src/infiniop/devices/musa/common_musa.h
+2
-0
src/infiniop/devices/musa/musa_handle.cc
src/infiniop/devices/musa/musa_handle.cc
+1
-0
src/infiniop/devices/musa/pool.h
src/infiniop/devices/musa/pool.h
+1
-1
src/infiniop/ops/gemm/operator.cc
src/infiniop/ops/gemm/operator.cc
+16
-0
src/infiniop/ops/matmul/musa/matmul_musa.h
src/infiniop/ops/matmul/musa/matmul_musa.h
+8
-0
src/infiniop/ops/matmul/musa/matmul_musa.mu
src/infiniop/ops/matmul/musa/matmul_musa.mu
+142
-0
test/infiniop/libinfiniop/utils.py
test/infiniop/libinfiniop/utils.py
+10
-0
xmake.lua
xmake.lua
+1
-1
xmake/musa.lua
xmake/musa.lua
+6
-2
No files found.
src/infiniop/devices/musa/common_musa.h
View file @
5450c707
...
...
@@ -5,6 +5,8 @@
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
// #include <musa_fp16.h>
#include <musa_fp16_mtgpu.h>
#include <musa_runtime_api.h>
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
...
...
src/infiniop/devices/musa/musa_handle.cc
View file @
5450c707
...
...
@@ -12,6 +12,7 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
infiniStatus_t
Handle
::
Internal
::
useMublas
(
MUstream
stream
,
const
Fn
<
mublasHandle_t
>
&
f
)
const
{
mublasHandle_t
*
handle
=
mublas_handles
.
pop
();
if
(
!
handle
)
{
handle
=
new
mublasHandle_t
;
CHECK_MUBLAS
(
mublasCreate
(
handle
));
}
CHECK_MUBLAS
(
mublasSetStream
(
*
handle
,
stream
));
...
...
src/infiniop/devices/musa/pool.h
View file @
5450c707
...
...
@@ -21,7 +21,7 @@ public:
void
push
(
T
*
val
)
const
{
Node
<
T
>
*
new_node
=
new
Node
<
T
>
(
val
);
new_node
->
next
=
_head
.
load
();
while
(
!
_head
.
compare_exchange_weak
(
new_node
->
next
,
new_node
))
;
while
(
!
_head
.
compare_exchange_weak
(
new_node
->
next
,
new_node
))
{}
}
T
*
pop
()
const
{
...
...
src/infiniop/ops/gemm/operator.cc
View file @
5450c707
...
...
@@ -17,6 +17,9 @@
#ifdef ENABLE_METAX_API
#include "maca/gemm_maca.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/matmul_musa.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
#endif
...
...
@@ -54,6 +57,10 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
musa
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
...
...
@@ -92,6 +99,9 @@ infiniopGetGemmWorkspaceSize(
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
maca
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
musa
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
...
...
@@ -138,6 +148,9 @@ __C infiniStatus_t infiniopGemm(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
musa
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
...
...
@@ -174,6 +187,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
musa
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
...
...
src/infiniop/ops/matmul/musa/matmul_musa.h
0 → 100644
View file @
5450c707
#ifndef MATMUL_MUSA_H
#define MATMUL_MUSA_H

// MUSA (Moore Threads GPU) matmul operator descriptor.
// Declares op::matmul::musa::Descriptor via the shared DESCRIPTOR macro
// from ../matmul.h; the implementation lives in matmul_musa.mu.
//
// Guard renamed from __MATMUL_MUSA_CUH__: identifiers beginning with a
// double underscore are reserved to the implementation in C++, and the
// "CUH" (CUDA header) suffix was misleading for a MUSA source tree.
#include "../matmul.h"

DESCRIPTOR(musa)

#endif // MATMUL_MUSA_H
src/infiniop/ops/matmul/musa/matmul_musa.mu
0 → 100644
View file @
5450c707
#include "../../../devices/musa/common_musa.h"
#include "../../../devices/musa/musa_handle.h"
#include "matmul_musa.h"
namespace op::matmul::musa {

// Per-descriptor device state. `internal` keeps the device handle's
// shared internals alive so the pooled mublas handles it owns (see
// useMublas in musa_handle.cc) remain valid for this descriptor's lifetime.
struct Descriptor::Opaque {
    std::shared_ptr<device::musa::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    // Opaque is allocated with `new` in create(); released here.
    delete _opaque;
}
/// Build a MUSA matmul descriptor for C = A @ B with the given tensor
/// descriptors.
///
/// @param handle_   Device handle; must be a device::musa::Handle.
/// @param desc_ptr  Out: receives the newly allocated Descriptor on success.
/// @param c_desc    Output tensor descriptor (its dtype selects the kernel).
/// @param a_desc    Left operand descriptor.
/// @param b_desc    Right operand descriptor.
/// @return INFINI_STATUS_BAD_TENSOR_DTYPE for dtypes other than f16/f32,
///         any failure status produced by MatmulInfo, else SUCCESS.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {
    auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
    auto dtype = c_desc->dtype();

    // The mublas GEMM path below supports only fp16 and fp32.
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Initialize defensively: MatmulInfo is expected to write `status`,
    // but reading it uninitialized would be UB should any path skip it.
    infiniStatus_t status = INFINI_STATUS_SUCCESS;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    // Workspace size is 0: mublasGemmStridedBatchedEx needs no scratch buffer.
    *desc_ptr = new Descriptor(
        dtype, info, 0,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Run the (strided-batched) GEMM C = alpha * A @ B + beta * C via mublas,
// with element type Tdata (half or float) and geometry/strides taken from
// `info`. Borrows a pooled mublas handle bound to `stream` for the call.
// Returns the first failing status surfaced by CHECK_MUBLAS/CHECK_STATUS.
template <typename Tdata>
infiniStatus_t calculate(
    const MatmulInfo &info,
    std::shared_ptr<device::musa::Handle::Internal> &_internal,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) {
    musaDataType a_type, b_type, c_type;
    mublasComputeType_t compute_type;
    // Scalars must be passed to mublas with a type matching the compute
    // type, so convert the float alpha/beta up front.
    Tdata alpha_, beta_;
    if constexpr (std::is_same<Tdata, half>::value) {
        alpha_ = __float2half(alpha);
        beta_ = __float2half(beta);
        a_type = b_type = c_type = MUSA_R_16F;
        compute_type = MUBLAS_COMPUTE_16F;
    } else {
        alpha_ = alpha;
        beta_ = beta;
        a_type = b_type = c_type = MUSA_R_32F;
        // NOTE(review): the TF32 fast path trades mantissa bits for speed
        // on f32 inputs — confirm the precision loss is acceptable here.
        compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
    }
    // NOTE(review): presumably `is_transed` marks that MatmulInfo rewrote
    // the problem as C^T = B^T A^T to fit the column-major layout, which
    // is realized by swapping the operand pointers — confirm in matmul.h.
    if (info.is_transed) {
        std::swap(a, b);
    }
    // row_stride == 1 means the matrix is column-contiguous already;
    // otherwise ask mublas to transpose on the fly.
    auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
    auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
    CHECK_STATUS(_internal->useMublas(
        (MUstream)stream,
        [&](mublasHandle_t handle) {
            CHECK_MUBLAS(
                mublasGemmStridedBatchedEx(
                    handle,
                    op_a,
                    op_b,
                    static_cast<int>(info.m),
                    static_cast<int>(info.n),
                    static_cast<int>(info.k),
                    &alpha_,
                    a,
                    a_type,
                    static_cast<int>(info.a_matrix.ld()),
                    info.a_matrix.stride,
                    b,
                    b_type,
                    static_cast<int>(info.b_matrix.ld()),
                    info.b_matrix.stride,
                    &beta_,
                    c,
                    c_type,
                    static_cast<int>(info.c_matrix.ld()),
                    info.c_matrix.stride,
                    static_cast<int>(info.batch),
                    compute_type,
                    MUBLAS_GEMM_DEFAULT));
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}
// Execute the GEMM described by this descriptor on `stream`, dispatching
// on the element type that was fixed at creation time.
// `workspace` / `workspace_size` are accepted for interface uniformity but
// unused: the mublas path requires no scratch memory (workspace size is 0).
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *c,
                                     float beta,
                                     const void *a,
                                     const void *b,
                                     float alpha,
                                     void *stream) const {
    if (_dtype == INFINI_DTYPE_F16) {
        return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return musa::calculate<float>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
    }
    // create() rejects every other dtype, so this is a defensive fallback.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::matmul::musa
test/infiniop/libinfiniop/utils.py
View file @
5450c707
...
...
@@ -171,6 +171,11 @@ def get_args():
action
=
"store_true"
,
help
=
"Run METAX GPU test"
,
)
parser
.
add_argument
(
"--moore"
,
action
=
"store_true"
,
help
=
"Run MTHREADS GPU test"
,
)
parser
.
add_argument
(
"--kunlun"
,
action
=
"store_true"
,
...
...
@@ -443,6 +448,11 @@ def get_test_devices(args):
import
torch
devices_to_test
.
append
(
InfiniDeviceEnum
.
METAX
)
if
args
.
moore
:
import
torch
import
torch_musa
devices_to_test
.
append
(
InfiniDeviceEnum
.
MOORE
)
if
args
.
kunlun
:
import
torch_xmlir
...
...
xmake.lua
View file @
5450c707
...
...
@@ -184,7 +184,7 @@ target("infiniop")
add_deps
(
"infiniop-metax"
)
end
if
has_config
(
"moore-gpu"
)
then
add_deps
(
"infini
-musa
"
)
add_deps
(
"infini
op-moore
"
)
end
if
has_config
(
"kunlun-xpu"
)
then
add_deps
(
"infiniop-kunlun"
)
...
...
xmake/musa.lua
View file @
5450c707
...
...
@@ -28,10 +28,14 @@ rule("mu")
end
)
rule_end
()
target
(
"infini
-musa
"
)
target
(
"infini
op-moore
"
)
set_kind
(
"static"
)
on_install
(
function
(
target
)
end
)
set_languages
(
"cxx17"
)
set_warnings
(
"all"
)
add_cxflags
(
"-lstdc++ -Wall -fPIC"
)
add_files
(
"../src/infiniop/devices/musa/*.cc"
,
"../src/infiniop/ops/*/musa/*.cc"
)
add_files
(
"
src
/ops/*/musa/*.mu"
,
{
rule
=
"mu"
})
add_files
(
"
../src/infiniop
/ops/*/musa/*.mu"
,
{
rule
=
"mu"
})
add_cxflags
(
"-lstdc++ -Wall -fPIC"
)
target_end
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment