Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5ffd00ea
Commit
5ffd00ea
authored
Sep 03, 2025
by
zhushuang
Browse files
feat: implement GEMM with MUBLAS and MUDNN backends in moore gpu
parent
c7373fee
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
320 additions
and
7 deletions
+320
-7
src/infiniop/ops/gemm/moore/gemm_moore.h
src/infiniop/ops/gemm/moore/gemm_moore.h
+98
-2
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h
+8
-0
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu
+5
-5
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h
+8
-0
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu
+198
-0
xmake/moore.lua
xmake/moore.lua
+3
-0
No files found.
src/infiniop/ops/gemm/moore/gemm_moore.h
View file @
5ffd00ea
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__

#include "../gemm.h"
#include "mublas/gemm_mublas.h"
#include "mudnn/gemm_mudnn.h"

DESCRIPTOR(moore)

namespace op::gemm::moore {

// Descriptor class for GEMM operations on Moore devices.
// Acts as a thin facade that selects either the muBLAS or the muDNN
// backend at creation time, then forwards workspace queries and GEMM
// calculations to the backend-specific Descriptor it owns.
class Descriptor final : public InfiniopDescriptor {
public:
    // Exclusive owner of _impl: copying or moving would double-delete it.
    Descriptor(const Descriptor &) = delete;
    Descriptor &operator=(const Descriptor &) = delete;

    // Destructor: deletes the backend-specific descriptor through its
    // concrete type (required so the correct destructor runs — _impl is
    // stored as void*).
    ~Descriptor() {
        if (_backend == Backend::MUBLAS) {
            // static_cast (not reinterpret_cast) is the correct named cast
            // for converting void* back to the original object pointer type.
            delete static_cast<mublas::Descriptor *>(_impl);
        } else {
            delete static_cast<mudnn::Descriptor *>(_impl);
        }
    }

    // Returns the workspace size (in bytes) required by the selected backend.
    size_t workspaceSize() const {
        if (_backend == Backend::MUBLAS) {
            return static_cast<mublas::Descriptor *>(_impl)->workspaceSize();
        }
        return static_cast<mudnn::Descriptor *>(_impl)->workspaceSize();
    }

    // Static factory method to create a Descriptor instance.
    // Chooses the backend (mublas or mudnn) and constructs the
    // corresponding implementation internally. On backend failure the
    // partially built wrapper is destroyed and the backend status is
    // propagated to the caller; *desc_ptr is written only on success.
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t c_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc) {
        auto desc = new Descriptor(handle->device, handle->device_id);
        // Backend selection strategy:
        // Currently defaulting to MUDNN.
        // Can be modified to choose based on environment variables or
        // runtime parameters.
        desc->_backend = Backend::MUDNN;
        if (desc->_backend == Backend::MUBLAS) {
            mublas::Descriptor *impl;
            auto status = mublas::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
            if (status != INFINI_STATUS_SUCCESS) {
                delete desc;
                return status;
            }
            desc->_impl = impl;
        } else {
            mudnn::Descriptor *impl;
            auto status = mudnn::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
            if (status != INFINI_STATUS_SUCCESS) {
                delete desc;
                return status;
            }
            desc->_impl = impl;
        }
        *desc_ptr = desc;
        return INFINI_STATUS_SUCCESS;
    }

    // Unified GEMM calculation interface.
    // Forwards all arguments unchanged to the selected backend's
    // calculate() implementation.
    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *c,
        float beta,
        const void *a,
        const void *b,
        float alpha,
        void *stream) const {
        if (_backend == Backend::MUBLAS) {
            return static_cast<mublas::Descriptor *>(_impl)
                ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
        }
        return static_cast<mudnn::Descriptor *>(_impl)
            ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
    }

private:
    // Private constructor: ensures users cannot directly instantiate
    // Descriptor; instances must be created via the static create()
    // factory method. _backend is initialized (to the current default,
    // MUDNN) so the destructor is well-defined even if a future create()
    // path deletes the wrapper before explicitly selecting a backend.
    Descriptor(infiniDevice_t device_type, int device_id)
        : InfiniopDescriptor{device_type, device_id},
          _backend(Backend::MUDNN),
          _impl(nullptr) {}

    // Enum to indicate which backend is being used internally.
    enum class Backend { MUBLAS,
                         MUDNN };

    Backend _backend; // Currently selected MUBLAS/MUDNN backend
    void *_impl;      // Pointer to backend-specific descriptor
                      // (mublas::Descriptor* or mudnn::Descriptor*)
};

} // namespace op::gemm::moore

#endif // __GEMM_MOORE_H__
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h
0 → 100644
View file @
5ffd00ea
#ifndef __GEMM_MUBLAS_H__
#define __GEMM_MUBLAS_H__

#include "../../gemm.h"

// Declares the op::gemm::mublas::Descriptor type via the shared
// DESCRIPTOR macro from ../../gemm.h; the implementation lives in
// mublas/gemm_mublas.mu.
DESCRIPTOR(mublas)

#endif // __GEMM_MUBLAS_H__
src/infiniop/ops/gemm/moore/gemm_m
oore
.mu
→
src/infiniop/ops/gemm/moore/
mublas/
gemm_m
ublas
.mu
View file @
5ffd00ea
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "gemm_m
oore
.h"
#include "../../../
../
devices/moore/moore_common.h"
#include "../../../
../
devices/moore/moore_handle.h"
#include "gemm_m
ublas
.h"
namespace op::gemm::m
oore
{
namespace op::gemm::m
ublas
{
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
...
...
@@ -122,4 +122,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace op::gemm::m
oore
} // namespace op::gemm::m
ublas
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h
0 → 100644
View file @
5ffd00ea
#ifndef __GEMM_MUDNN_H__
#define __GEMM_MUDNN_H__

#include "../../gemm.h"

// Declares the op::gemm::mudnn::Descriptor type via the shared
// DESCRIPTOR macro from ../../gemm.h; the implementation lives in
// mudnn/gemm_mudnn.mu.
DESCRIPTOR(mudnn)

#endif // __GEMM_MUDNN_H__
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu
0 → 100644
View file @
5ffd00ea
#include "../../../../devices/moore/moore_common.h"
#include "../../../../devices/moore/moore_handle.h"
#include "gemm_mudnn.h"

#include <musa_bf16.h>

namespace op::gemm::mudnn {

// Backend-private state for the muDNN GEMM descriptor: a shared
// reference to the per-device handle internals, used below to obtain a
// muDNN handle bound to the execution stream.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Releases the heap-allocated Opaque created in Descriptor::create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Builds a muDNN GEMM descriptor from the c/a/b tensor descriptors.
// Validates the output dtype (F16 / F32 / BF16) and precomputes the
// matmul geometry with a row-major layout; fails with the macro-reported
// status on unsupported dtypes or incompatible shapes.
// NOTE(review): the third Descriptor argument is 0 — presumably the
// caller-visible workspace size, since this backend allocates its own
// scratch memory at Run time (see the MemoryMaintainer in calculate);
// confirm against the DESCRIPTOR macro definition.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = c_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(
        dtype, result.take(), 0,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dtype-templated driver that runs one batched GEMM through muDNN's
// BatchMatMul operator on the given stream. Tdata selects the element
// type of all three tensors (half / __mt_bfloat16 / float). alpha and
// beta are forwarded to the operator via SetAlpha/SetBeta.
template <typename Tdata>
infiniStatus_t calculate(
    const MatmulInfo &info,
    std::shared_ptr<device::moore::Handle::Internal> &_internal,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream)
{
    // 0. For muDNN development, refer to the official documentation and the following headers:
    //    - /usr/local/musa/include/mudnn_base.h
    //    - /usr/local/musa/include/mudnn_math.h
    //    - /usr/local/musa/include/mudnn.h
    // 1. Create BatchMatMul operator (TENSOR compute mode enables tensor-core paths).
    auto matmul_operator = std::make_unique<::musa::dnn::BatchMatMul>();
    matmul_operator->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);
    // 2. Use _internal->useMudnn to manage the muDNN handle bound to this stream.
    return _internal->useMudnn((musaStream_t)stream, [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
        // 3. Create BatchMatMul tensors and map Tdata to the muDNN element type.
        ::musa::dnn::Tensor out, left, right;
        if constexpr (std::is_same<Tdata, half>::value) {
            out.SetType(::musa::dnn::Tensor::Type::HALF);
            left.SetType(::musa::dnn::Tensor::Type::HALF);
            right.SetType(::musa::dnn::Tensor::Type::HALF);
        }
        else if constexpr (std::is_same<Tdata, __mt_bfloat16>::value){
            out.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
            left.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
            right.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
        }
        else{
            out.SetType(::musa::dnn::Tensor::Type::FLOAT);
            left.SetType(::musa::dnn::Tensor::Type::FLOAT);
            right.SetType(::musa::dnn::Tensor::Type::FLOAT);
        }
        // 4. Bind device addresses to the BatchMatMul tensors.
        out.SetAddr(c);
        left.SetAddr(a);
        right.SetAddr(b);
        // 5. Configure tensor `left` (operand a).
        //    col_stride != 1 is treated as a transposed (column-major) view,
        //    so the dims are given as (batch, k, m) instead of (batch, m, k);
        //    the transpose flags in step 9 compensate.
        std::array<int64_t, 3> a_dims_array;
        std::array<int64_t, 3> a_stride_array;
        if (info.a_matrix.col_stride != 1) {
            a_dims_array = { static_cast<int64_t>(info.batch),
                             static_cast<int64_t>(info.k),
                             static_cast<int64_t>(info.m) };
        } else {
            a_dims_array = { static_cast<int64_t>(info.batch),
                             static_cast<int64_t>(info.m),
                             static_cast<int64_t>(info.k) };
        }
        a_stride_array = { static_cast<int64_t>(info.a_matrix.stride),
                           static_cast<int64_t>(info.a_matrix.ld()),
                           1 };
        left.SetNdInfo(static_cast<int>(a_dims_array.size()), a_dims_array.data(), a_stride_array.data());
        // 6. Configure tensor `right` (operand b), same convention as step 5.
        std::array<int64_t, 3> b_dims_array;
        std::array<int64_t, 3> b_stride_array;
        if (info.b_matrix.col_stride != 1) {
            b_dims_array = { static_cast<int64_t>(info.batch),
                             static_cast<int64_t>(info.n),
                             static_cast<int64_t>(info.k) };
        } else {
            b_dims_array = { static_cast<int64_t>(info.batch),
                             static_cast<int64_t>(info.k),
                             static_cast<int64_t>(info.n) };
        }
        b_stride_array = { static_cast<int64_t>(info.b_matrix.stride),
                           static_cast<int64_t>(info.b_matrix.ld()),
                           1 };
        right.SetNdInfo(static_cast<int>(b_dims_array.size()), b_dims_array.data(), b_stride_array.data());
        // 7. Configure tensor `out`; muDNN BatchMatMul output only supports
        //    a row-major tensor, so dims are always (batch, m, n).
        std::array<int64_t, 3> c_dims_array = { static_cast<int64_t>(info.batch),
                                                static_cast<int64_t>(info.m),
                                                static_cast<int64_t>(info.n) };
        std::array<int64_t, 3> c_stride_array = { static_cast<int64_t>(info.c_matrix.stride),
                                                  static_cast<int64_t>(info.c_matrix.ld()),
                                                  1 };
        out.SetNdInfo(static_cast<int>(c_dims_array.size()), c_dims_array.data(), c_stride_array.data());
        // 8. Workspace memory handler: lets muDNN allocate scratch on demand
        //    and free it via the bundled deleter.
        //    NOTE(review): musaMalloc's return status is not checked — on
        //    allocation failure ptr stays nullptr; confirm muDNN tolerates that.
        ::musa::dnn::MemoryMaintainer maintainer = [](size_t size) -> ::musa::dnn::MemoryHandler {
            void* ptr = nullptr;
            musaMalloc(&ptr, size);
            return ::musa::dnn::MemoryHandler(ptr, [](void* p) { if(p) musaFree(p); });
        };
        // 9. Transpose flags matching the dim swaps chosen in steps 5/6.
        if (info.a_matrix.col_stride == 1 && info.b_matrix.col_stride != 1)
            matmul_operator->SetTranspose(false, true);
        else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride == 1)
            matmul_operator->SetTranspose(true, false);
        else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride != 1)
            matmul_operator->SetTranspose(true, true);
        else
            matmul_operator->SetTranspose(false, false);
        // 10. BatchMatMul workspace query.
        //     NOTE(review): workspace_size_in_bytes is computed but never
        //     used — allocation is delegated to `maintainer` — and the call's
        //     status is ignored; consider checking or removing this query.
        size_t workspace_size_in_bytes = 0;
        matmul_operator->GetWorkspaceSize(mudnn_handle, workspace_size_in_bytes, out, left, right);
        // 11. Scaling factors: alpha, beta, gamma (gamma unused, set to 0).
        matmul_operator->SetAlpha(static_cast<double>(alpha));
        matmul_operator->SetBeta(static_cast<double>(beta));
        matmul_operator->SetGamma(0.0);
        // 12. Run the operator on the stream bound to mudnn_handle.
        //     NOTE(review): Run's return status is discarded — a kernel-launch
        //     failure would still report INFINI_STATUS_SUCCESS; consider
        //     mapping the muDNN status to an infiniStatus_t.
        matmul_operator->Run(
            mudnn_handle,
            out,
            left,
            right,
            static_cast<int64_t>(info.batch),
            static_cast<int64_t>(info.m),
            static_cast<int64_t>(info.n),
            static_cast<int64_t>(info.k),
            static_cast<int64_t>(info.a_matrix.ld()),
            static_cast<int64_t>(info.b_matrix.ld()),
            static_cast<int64_t>(info.c_matrix.ld()),
            static_cast<int64_t>(info.a_matrix.stride),
            static_cast<int64_t>(info.b_matrix.stride),
            static_cast<int64_t>(info.c_matrix.stride),
            maintainer
        );
        return INFINI_STATUS_SUCCESS;
    });
}
// Dispatches the GEMM to the dtype-specialized muDNN driver.
// `workspace` / `workspace_size` are part of the common backend interface
// but are not forwarded: this backend obtains scratch memory internally.
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *c,
                                     float beta,
                                     const void *a,
                                     const void *b,
                                     float alpha,
                                     void *stream) const {
    if (_dtype == INFINI_DTYPE_F16) {
        return mudnn::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return mudnn::calculate<float>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return mudnn::calculate<__mt_bfloat16>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
    }
    // Unsupported element type for this backend.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::gemm::mudnn
xmake/moore.lua
View file @
5ffd00ea
...
...
@@ -44,6 +44,9 @@ target("infiniop-moore")
-- Link against libstdc++, emit position-independent code, and suppress
-- -Wcomment warnings (presumably triggered by vendor headers — confirm).
add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")

-- Device-layer sources for the Moore backend.
add_files("../src/infiniop/devices/moore/*.cc")

-- Operator kernels (.mu) compiled with the custom "mu" rule.
add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"})

-- Add source files for Moore muBLAS/muDNN GEMM backends.
add_files("../src/infiniop/ops/gemm/moore/*/*.mu", {rule = "mu"})

target_end()
target
(
"infinirt-moore"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment