jerrrrry / infinicore · Commit d4f726de
Authored Feb 11, 2026 by zhushuang

issue/972 - feat: add scaled_mm with muDNN BatchMatMul for moore gpu

Parent: 012df56c
Showing 6 changed files with 316 additions and 28 deletions (+316 −28)
src/infiniop/ops/scaled_mm/info.h                    +1   −1
src/infiniop/ops/scaled_mm/int8_gemm.h               +2   −2
src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h   +7   −0
src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu  +238 −0
src/infiniop/ops/scaled_mm/operator.cc               +16  −0
test/infiniop/scaled_mm_int8.py                      +52  −25
src/infiniop/ops/scaled_mm/info.h

-#ifndef __GEMM_INFO_H__
+#ifndef __I8GEMM_INFO_H__
 #define __I8GEMM_INFO_H__
 #include "../../../utils.h"
 ...
src/infiniop/ops/scaled_mm/int8_gemm.h

@@ -18,8 +18,8 @@
                    size_t workspace_size,                                        \
                    infiniDtype_t out_dtype,                                      \
                    infiniDevice_t device_type, int device_id)                    \
-        : InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype),     \
-          _opaque(opaque), _info(info), _workspace_size(workspace_size) {}       \
+        : InfiniopDescriptor{device_type, device_id}, _opaque(opaque),           \
+          _workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \
                                                                                  \
 public:                                                                          \
     ~Descriptor();                                                               \
 ...
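The initializer-list reorder above most likely aligns the list with the members' declaration order: C++ always initializes non-static members in declaration order, so a mismatched list draws -Wreorder warnings without changing behavior. A minimal, standalone illustration (not project code, names invented):

#include <iostream>

struct Demo {
    int a;
    int b;
    // Written as b(...), a(...), but a is still initialized first because it is
    // declared first; compilers warn (-Wreorder) about the mismatched order.
    Demo() : b(2), a(1) {}
};

int main() {
    Demo d;
    std::cout << d.a << " " << d.b << "\n"; // prints "1 2"
}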
src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h (new file, mode 100644)

#ifndef __INT8_GEMM_MOORE_API_H__
#define __INT8_GEMM_MOORE_API_H__

#include "../int8_gemm.h"

DESCRIPTOR(moore)

#endif // __INT8_GEMM_MOORE_API_H__
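The whole header is a single DESCRIPTOR(moore) expansion: the macro in int8_gemm.h stamps out a backend-specific Descriptor class (constructor, ~Descriptor(), and the create/calculate entry points defined in the .mu file below) inside a per-backend namespace. A self-contained toy of that pattern, with invented names, just to show the mechanism:

#include <cstdio>

// Toy stand-in for DESCRIPTOR(ns): stamps a per-backend Descriptor into op::i8gemm::ns.
#define TOY_DESCRIPTOR(NAMESPACE)                                       \
    namespace op::i8gemm::NAMESPACE {                                   \
    struct Descriptor {                                                 \
        int device_id;                                                  \
        explicit Descriptor(int id) : device_id(id) {}                  \
        void calculate() const {                                        \
            std::printf(#NAMESPACE " backend, device %d\n", device_id); \
        }                                                               \
    };                                                                  \
    }

TOY_DESCRIPTOR(moore) // analogous to DESCRIPTOR(moore) in the header above

int main() {
    op::i8gemm::moore::Descriptor d(0);
    d.calculate();
}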
src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu (new file, mode 100644)
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "int8_gemm_moore.h"
namespace op::i8gemm::moore {
static void moore_i8gemm_launch(
const I8GemmInfo &info,
std::shared_ptr<device::moore::Handle::Internal> &internal,
void* out,
const int8_t* A,
const int8_t* B,
const float* A_scale,
const float* B_scale,
const void* bias,
infiniDtype_t out_dtype,
musaStream_t stream)
{
internal->useMudnn(stream,
[&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
// 1. Operator
auto matmul = std::make_unique<::musa::dnn::BatchMatMul>();
matmul->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);
// 2. Tensors
::musa::dnn::Tensor out_t, a_t, b_t, bias_t;
::musa::dnn::Tensor scale_a_t, scale_b_t;
// 3. Output dtype
if (out_dtype == INFINI_DTYPE_F16) {
out_t.SetType(::musa::dnn::Tensor::Type::HALF);
bias_t.SetType(::musa::dnn::Tensor::Type::HALF);
} else {
out_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
bias_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
}
// 4. Input INT8
a_t.SetType(::musa::dnn::Tensor::Type::INT8);
b_t.SetType(::musa::dnn::Tensor::Type::INT8);
// 5. Scale (per-tensor)
scale_a_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
scale_b_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
// 6. Bind memory
out_t.SetAddr(out);
a_t.SetAddr(const_cast<int8_t*>(A));
b_t.SetAddr(const_cast<int8_t*>(B));
scale_a_t.SetAddr(const_cast<float*>(A_scale));
scale_b_t.SetAddr(const_cast<float*>(B_scale));
if (bias)
bias_t.SetAddr(const_cast<void*>(bias));
// 7. A NdInfo
{
std::array<int64_t,3> dims;
std::array<int64_t,3> strides;
if (info.a_matrix.col_stride != 1) {
dims = {info.batch, info.k, info.m};
} else {
dims = {info.batch, info.m, info.k};
}
strides = {
info.a_matrix.stride,
info.a_matrix.ld(),
1
};
a_t.SetNdInfo(3, dims.data(), strides.data());
}
// 8. B NdInfo
{
std::array<int64_t,3> dims;
std::array<int64_t,3> strides;
if (info.b_matrix.col_stride != 1) {
dims = {info.batch, info.n, info.k};
} else {
dims = {info.batch, info.k, info.n};
}
strides = {
info.b_matrix.stride,
info.b_matrix.ld(),
1
};
b_t.SetNdInfo(3, dims.data(), strides.data());
}
// 9. out NdInfo
{
std::array<int64_t, 3> dims = {
info.batch,
info.m,
info.n
};
std::array<int64_t, 3> strides = {
info.m * info.n,
info.n,
1
};
out_t.SetNdInfo(3, dims.data(), strides.data());
}
// 10. Bias & scale NdInfo
if (bias) {
std::array<int64_t,1> dims = { info.n };
std::array<int64_t,1> strides = { 1 };
bias_t.SetNdInfo(1, dims.data(), strides.data());
}
{
std::array<int64_t,3> a_scale_dims = { info.batch, info.m, 1 };
std::array<int64_t,3> a_scale_strides = { info.m, 1, 1 };
scale_a_t.SetNdInfo(3, a_scale_dims.data(), a_scale_strides.data());
std::array<int64_t,3> b_scale_dims = { info.batch, 1, info.n };
std::array<int64_t,3> b_scale_strides = { info.n, 1, 1 };
scale_b_t.SetNdInfo(3, b_scale_dims.data(), b_scale_strides.data());
}
// 11. Transpose
matmul->SetTranspose(
info.a_matrix.col_stride != 1,
info.b_matrix.col_stride != 1);
// 12. Lt param (no epilogue enum)
::musa::dnn::MatMulLtParam lt_param;
lt_param.SetScale(
scale_a_t,
scale_b_t,
::musa::dnn::Tensor(),
::musa::dnn::Tensor());
// 13. Alpha / Beta / Gamma
matmul->SetAlpha(1.0);
matmul->SetBeta(0.0);
matmul->SetGamma(1.0);
// 14. Workspace
::musa::dnn::MemoryMaintainer maintainer =
[](size_t size) {
void* ptr = nullptr;
musaMalloc(&ptr, size);
return ::musa::dnn::MemoryHandler(
ptr,
[](void* p) { if (p) musaFree(p); });
};
// 15. Run
matmul->RunLt(
mudnn_handle,
out_t,
a_t,
b_t,
::musa::dnn::Tensor(),
bias ? bias_t : ::musa::dnn::Tensor(),
lt_param,
maintainer);
return INFINI_STATUS_SUCCESS;
});
}
/* ================= Descriptor ================= */
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_scale_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scale_desc)
{
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = out_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
auto result = I8GemmInfo::create(
out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
new Opaque{handle->internal()},
result.take(),
0,
dtype,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *a,
const void *a_scale,
const void *b,
const void *b_scale,
void *stream_) const
{
moore_i8gemm_launch(
_info,
_opaque->internal,
out,
static_cast<const int8_t*>(a),
static_cast<const int8_t*>(b),
static_cast<const float*>(a_scale),
static_cast<const float*>(b_scale),
bias,
_out_dtype,
reinterpret_cast<musaStream_t>(stream_));
return INFINI_STATUS_SUCCESS;
}
} // namespace op::i8gemm::moore
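For reference, the math this launch asks muDNN to perform (per the SetScale binding above and the test's torch_scaled_mm reference below) is an int8 matmul dequantized by a per-row a_scale and per-column b_scale, plus an optional bias. A minimal CPU sketch of that math, independent of muDNN; the helper name and the plain row-major, single-batch layout are illustrative only:

#include <cstdint>
#include <cstdio>
#include <vector>

// out[m][n] = (sum_k A[m][k] * B[k][n]) * a_scale[m] * b_scale[n] + bias[n]
std::vector<float> scaled_i8_gemm_ref(const std::vector<int8_t> &A,      // M x K, row-major
                                      const std::vector<int8_t> &B,      // K x N, row-major
                                      const std::vector<float> &a_scale, // M (per-row)
                                      const std::vector<float> &b_scale, // N (per-column)
                                      const std::vector<float> &bias,    // N, may be empty
                                      int M, int K, int N) {
    std::vector<float> out(static_cast<size_t>(M) * N);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            int32_t acc = 0; // int8 products accumulated in int32
            for (int k = 0; k < K; ++k) {
                acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);
            }
            float v = float(acc) * a_scale[m] * b_scale[n];
            if (!bias.empty()) {
                v += bias[n];
            }
            out[m * N + n] = v; // the real op then casts to F16 or BF16
        }
    }
    return out;
}

int main() {
    // 2x2 example: A = [[1, 2], [3, 4]] (int8), B = identity (int8)
    auto out = scaled_i8_gemm_ref({1, 2, 3, 4}, {1, 0, 0, 1},
                                  {0.5f, 0.25f}, {1.0f, 2.0f}, {}, 2, 2, 2);
    std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // 0.5 2 0.75 2
}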
src/infiniop/ops/scaled_mm/operator.cc

@@ -6,6 +6,10 @@
 #include "nvidia/int8_gemm_nvidia.cuh"
 #endif
+#if defined(ENABLE_MOORE_API)
+#include "moore/int8_gemm_moore.h"
+#endif

 __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
                                                   infiniopI8GemmDescriptor_t *desc_ptr,
                                                   infiniopTensorDescriptor_t out_desc,
 ...
@@ -31,6 +35,9 @@ __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
 #endif
 #if defined(ENABLE_QY_API)
         CREATE(INFINI_DEVICE_QY, nvidia)
 #endif
+#if defined(ENABLE_MOORE_API)
+        CREATE(INFINI_DEVICE_MOORE, moore)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 ...
@@ -49,6 +56,9 @@ __C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t des
 #endif
 #if defined(ENABLE_QY_API)
         GET(INFINI_DEVICE_QY, nvidia)
 #endif
+#if defined(ENABLE_MOORE_API)
+        GET(INFINI_DEVICE_MOORE, moore)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 ...
@@ -76,6 +86,9 @@ __C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
 #endif
 #if defined(ENABLE_QY_API)
         CACULATE(INFINI_DEVICE_QY, nvidia)
 #endif
+#if defined(ENABLE_MOORE_API)
+        CACULATE(INFINI_DEVICE_MOORE, moore)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 ...
@@ -94,6 +107,9 @@ __C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t de
 #endif
 #if defined(ENABLE_QY_API)
         DESTROY(INFINI_DEVICE_QY, nvidia)
 #endif
+#if defined(ENABLE_MOORE_API)
+        DESTROY(INFINI_DEVICE_MOORE, moore)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 ...
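Each hunk adds one more case to a device-type dispatch: the CREATE / GET / CACULATE / DESTROY macros expand to a case that forwards to the named backend namespace, falling through to INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED otherwise. A self-contained toy of that dispatch shape (macro, enum, and function names here are invented, not the project's):

#include <cstdio>

enum ToyDevice { DEVICE_NVIDIA, DEVICE_MOORE, DEVICE_OTHER };

namespace nvidia { void create() { std::puts("nvidia create"); } }
namespace moore  { void create() { std::puts("moore create"); } }

// Toy stand-in for the CREATE(DEVICE_ENUM, namespace) case generator.
#define TOY_CREATE(DEVICE_ENUM, NAMESPACE) \
    case DEVICE_ENUM:                      \
        NAMESPACE::create();               \
        return 0;

int dispatch_create(ToyDevice device) {
    switch (device) {
        TOY_CREATE(DEVICE_NVIDIA, nvidia)
#if 1 // stands in for #if defined(ENABLE_MOORE_API)
        TOY_CREATE(DEVICE_MOORE, moore)
#endif
    default:
        return -1; // INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED in the real code
    }
}

int main() {
    return dispatch_create(DEVICE_MOORE); // prints "moore create"
}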
test/infiniop/scaled_mm_int8.py

@@ -25,6 +25,7 @@ from enum import Enum, auto
 # These are not meant to be imported from other modules
 _TEST_CASES_ = [
     # x_shape, w_shape, y_shape, alpha, beta
+    ((2, 4), (4, 2), (2, 2)),
     ((128, 512), (512, 1024), (128, 1024)),
     ((256, 1024), (1024, 2048), (256, 2048)),
     ((1024, 2048), (2048, 1024), (1024, 1024)),
 ...
@@ -59,10 +60,8 @@ _TOLERANCE_MAP = {
 DEBUG = False
 PROFILE = False
 NUM_PRERUN = 10
-NUM_ITERATIONS = 1000
+NUM_ITERATIONS = 100

-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
 ...
@@ -72,6 +71,7 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
     return o.to(out_dtype)

 def test(
     handle,
     device,
 ...
@@ -87,30 +87,38 @@ def test(
 )
     M, K = x_shape
     N = w_shape[1]
-    x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
-    weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
-    x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
-    weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
-    bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10
-    ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias)
-    x_packed = TestTensor(
-        (M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
-    )
-    x_scale = TestTensor(
-        (M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
-    )
-    weights = TestTensor(
-        (K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
-    )
-    weights_scale = TestTensor(
-        (N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
-    )
-    y = TestTensor(y_shape, None, dtype, device)
-    bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)
+    x_packed = TestTensor(
+        (M, K),
+        None,
+        InfiniDtype.I8,
+        device,
+        mode="randint",
+        randint_low=-128,
+        randint_high=127,
+    )
+    weights = TestTensor(
+        (K, N),
+        None,
+        InfiniDtype.I8,
+        device,
+        mode="randint",
+        randint_low=-128,
+        randint_high=127,
+    )
+    x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
+    weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
+    bias = TestTensor((N,), None, dtype, device, mode="random")
+    y = TestTensor(y_shape, None, dtype, device, mode="zeros")
+    ans = torch_scaled_mm(
+        x_packed.torch_tensor(),
+        weights.torch_tensor(),
+        x_scale.torch_tensor(),
+        weights_scale.torch_tensor(),
+        out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
+        bias=bias.torch_tensor(),
+    )
     descriptor = infiniopOperatorDescriptor_t()
     check_error(
 ...
@@ -164,7 +172,20 @@ def test(
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(
+            "PyTorch",
+            lambda: torch_scaled_mm(
+                x_packed.torch_tensor(),
+                weights.torch_tensor(),
+                x_scale.torch_tensor(),
+                weights_scale.torch_tensor(),
+                out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
+                bias=bias.torch_tensor()
+            ),
+            device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
 ...
@@ -181,6 +202,12 @@ if __name__ == "__main__":
     NUM_ITERATIONS = args.num_iterations
     for device in get_test_devices(args):
-        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+        # muDNN(v3101): INT8 quantized multiplication → BF16 output.
+        # Moore backend: BF16 output only.
+        if args.moore == True:
+            _TENSOR_DTYPES_MOORE = [InfiniDtype.BF16]
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES_MOORE)
+        else:
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
     print("\033[92mTest passed!\033[0m")