Commit d4f726de authored by zhushuang

issue/972 - feat: add scaled_mm with muDNN BatchMatMul for moore gpu

parent 012df56c
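For orientation, here is a minimal, hedged sketch of how the int8 scaled-matmul entry points wired up in this commit would be driven from the C API. The argument order of infiniopCreateI8GemmDescriptor and infiniopI8Gemm is assumed to mirror Descriptor::create and Descriptor::calculate in the diff below and has not been checked against the public infiniop header; treat it as an illustration, not the authoritative signature.

// Hedged sketch only: argument order assumed from Descriptor::create /
// Descriptor::calculate below; verify against the public infiniop header.
static infiniStatus_t run_i8gemm_once(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_scale_desc,
    void *out, const void *bias,
    const void *a, const void *a_scale,   // INT8 data, FP32 per-row scales
    const void *b, const void *b_scale,   // INT8 data, FP32 per-column scales
    void *stream) {
    infiniopI8GemmDescriptor_t desc = nullptr;
    infiniStatus_t st = infiniopCreateI8GemmDescriptor(
        handle, &desc, out_desc, bias_desc, a_desc, a_scale_desc, b_desc, b_scale_desc);
    if (st != INFINI_STATUS_SUCCESS) {
        return st;
    }
    size_t workspace_size = 0;            // 0 on the Moore backend added below
    st = infiniopGetI8GemmWorkspaceSize(desc, &workspace_size);
    if (st == INFINI_STATUS_SUCCESS) {
        st = infiniopI8Gemm(desc, /*workspace=*/nullptr, workspace_size,
                            out, bias, a, a_scale, b, b_scale, stream);
    }
    infiniopDestroyI8GemmDescriptor(desc);
    return st;
}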
-#ifndef __GEMM_INFO_H__
+#ifndef __I8GEMM_INFO_H__
+#define __I8GEMM_INFO_H__
#include "../../../utils.h"
......
@@ -18,8 +18,8 @@
size_t workspace_size, \
infiniDtype_t out_dtype, \
infiniDevice_t device_type, int device_id) \
-: InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
-_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
+: InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \
+_workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \
\
public: \
~Descriptor(); \
......
#ifndef __INT8_GEMM_MOORE_API_H__
#define __INT8_GEMM_MOORE_API_H__
#include "../int8_gemm.h"
DESCRIPTOR(moore)
#endif // __INT8_GEMM_MOORE_API_H__
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "int8_gemm_moore.h"
namespace op::i8gemm::moore {
static void moore_i8gemm_launch(
const I8GemmInfo &info,
std::shared_ptr<device::moore::Handle::Internal> &internal,
void* out,
const int8_t* A,
const int8_t* B,
const float* A_scale,
const float* B_scale,
const void* bias,
infiniDtype_t out_dtype,
musaStream_t stream)
{
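// Runs one int8 batched matmul through muDNN: out = (A_int8 x B_int8) scaled by
// the per-row A scales and per-column B scales, plus an optional bias, written
// as FP16 or BF16 on the given MUSA stream.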
internal->useMudnn(stream,
[&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
// 1. Operator
auto matmul = std::make_unique<::musa::dnn::BatchMatMul>();
matmul->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);
// 2. Tensors
::musa::dnn::Tensor out_t, a_t, b_t, bias_t;
::musa::dnn::Tensor scale_a_t, scale_b_t;
// 3. Output dtype
if (out_dtype == INFINI_DTYPE_F16) {
out_t.SetType(::musa::dnn::Tensor::Type::HALF);
bias_t.SetType(::musa::dnn::Tensor::Type::HALF);
} else {
out_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
bias_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
}
// 4. Input INT8
a_t.SetType(::musa::dnn::Tensor::Type::INT8);
b_t.SetType(::musa::dnn::Tensor::Type::INT8);
// 5. Scale factors (FP32; per-row for A, per-column for B, see NdInfo in step 10)
scale_a_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
scale_b_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
// 6. Bind memory
out_t.SetAddr(out);
a_t.SetAddr(const_cast<int8_t*>(A));
b_t.SetAddr(const_cast<int8_t*>(B));
scale_a_t.SetAddr(const_cast<float*>(A_scale));
scale_b_t.SetAddr(const_cast<float*>(B_scale));
if (bias)
bias_t.SetAddr(const_cast<void*>(bias));
// 7. A NdInfo
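// NdInfo describes the physical (row-major) layout. When col_stride != 1 the
// matrix is stored column-major, so its row-major view is the transposed
// (k x m) shape; SetTranspose() in step 11 then flips it back logically.
// Step 8 applies the same convention to B.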
{
std::array<int64_t,3> dims;
std::array<int64_t,3> strides;
if (info.a_matrix.col_stride != 1) {
dims = {info.batch, info.k, info.m};
} else {
dims = {info.batch, info.m, info.k};
}
strides = {
info.a_matrix.stride,
info.a_matrix.ld(),
1
};
a_t.SetNdInfo(3, dims.data(), strides.data());
}
// 8. B NdInfo
{
std::array<int64_t,3> dims;
std::array<int64_t,3> strides;
if (info.b_matrix.col_stride != 1) {
dims = {info.batch, info.n, info.k};
} else {
dims = {info.batch, info.k, info.n};
}
strides = {
info.b_matrix.stride,
info.b_matrix.ld(),
1
};
b_t.SetNdInfo(3, dims.data(), strides.data());
}
// 9. out NdInfo
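// The output is described as a dense row-major (batch, m, n) tensor; the
// strides are computed here rather than taken from the output descriptor.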
{
std::array<int64_t, 3> dims = {
info.batch,
info.m,
info.n
};
std::array<int64_t, 3> strides = {
info.m * info.n,
info.n,
1
};
out_t.SetNdInfo(3, dims.data(), strides.data());
}
// 10. Bias & scale NdInfo
if (bias) {
std::array<int64_t,1> dims = { info.n };
std::array<int64_t,1> strides = { 1 };
bias_t.SetNdInfo(1, dims.data(), strides.data());
}
{
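// Dequantization scales broadcast over the output: A's scale is shaped
// (batch, m, 1), one value per output row, and B's scale is shaped
// (batch, 1, n), one value per output column.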
std::array<int64_t,3> a_scale_dims = { info.batch, info.m, 1 };
std::array<int64_t,3> a_scale_strides = { info.m, 1, 1 };
scale_a_t.SetNdInfo(3, a_scale_dims.data(), a_scale_strides.data());
std::array<int64_t,3> b_scale_dims = { info.batch, 1, info.n };
std::array<int64_t,3> b_scale_strides = { info.n, 1, 1 };
scale_b_t.SetNdInfo(3, b_scale_dims.data(), b_scale_strides.data());
}
// 11. Transpose
matmul->SetTranspose(
info.a_matrix.col_stride != 1,
info.b_matrix.col_stride != 1);
// 12. Lt param (no epilogue enum)
::musa::dnn::MatMulLtParam lt_param;
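// Attach the A/B quantization scales. The two default-constructed tensors
// leave SetScale()'s remaining slots empty (presumably output-side scaling,
// unused here).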
lt_param.SetScale(
scale_a_t,
scale_b_t,
::musa::dnn::Tensor(),
::musa::dnn::Tensor());
// 13. Alpha / Beta
matmul->SetAlpha(1.0);
matmul->SetBeta(0.0);
matmul->SetGamma(1.0);
// 14. Workspace
::musa::dnn::MemoryMaintainer maintainer =
[](size_t size) {
void* ptr = nullptr;
musaMalloc(&ptr, size);
return ::musa::dnn::MemoryHandler(
ptr,
[](void* p) { if (p) musaFree(p); });
};
// 15. Run
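// RunLt arguments as passed here: output, A, B, a default-constructed
// placeholder tensor, the optional bias, the Lt scale parameters, and the
// workspace allocator.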
matmul->RunLt(
mudnn_handle,
out_t,
a_t,
b_t,
::musa::dnn::Tensor(),
bias ? bias_t : ::musa::dnn::Tensor(),
lt_param,
maintainer);
return INFINI_STATUS_SUCCESS;
});
}
/* ================= Descriptor ================= */
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_scale_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scale_desc)
{
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = out_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
auto result = I8GemmInfo::create(
out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
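// workspace_size is reported as 0: the muDNN path allocates its own scratch
// memory on demand through the MemoryMaintainer callback at launch time.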
*desc_ptr = new Descriptor(
new Opaque{handle->internal()},
result.take(),
0,
dtype,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *a,
const void *a_scale,
const void *b,
const void *b_scale,
void *stream_) const
{
moore_i8gemm_launch(
_info,
_opaque->internal,
out,
static_cast<const int8_t*>(a),
static_cast<const int8_t*>(b),
static_cast<const float*>(a_scale),
static_cast<const float*>(b_scale),
bias,
_out_dtype,
reinterpret_cast<musaStream_t>(stream_));
return INFINI_STATUS_SUCCESS;
}
} // namespace op::i8gemm::moore
@@ -6,6 +6,10 @@
#include "nvidia/int8_gemm_nvidia.cuh"
#endif
#if defined(ENABLE_MOORE_API)
#include "moore/int8_gemm_moore.h"
#endif
__C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
infiniopI8GemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
@@ -31,6 +35,9 @@ __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
#endif
#if defined(ENABLE_QY_API)
CREATE(INFINI_DEVICE_QY, nvidia)
#endif
#if defined(ENABLE_MOORE_API)
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -49,6 +56,9 @@ __C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t des
#endif
#if defined(ENABLE_QY_API)
GET(INFINI_DEVICE_QY, nvidia)
#endif
#if defined(ENABLE_MOORE_API)
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -76,6 +86,9 @@ __C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
#endif
#if defined(ENABLE_QY_API)
CACULATE(INFINI_DEVICE_QY, nvidia)
#endif
#if defined(ENABLE_MOORE_API)
CACULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -94,6 +107,9 @@ __C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t de
#endif
#if defined(ENABLE_QY_API)
DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
#if defined(ENABLE_MOORE_API)
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
@@ -25,6 +25,7 @@ from enum import Enum, auto
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# x_shape, w_shape, y_shape, alpha, beta
((2, 4), (4, 2), (2, 2)),
((128, 512), (512, 1024), (128, 1024)),
((256, 1024), (1024, 2048), (256, 2048)),
((1024, 2048), (2048, 1024), (1024, 1024)),
@@ -59,10 +60,8 @@ _TOLERANCE_MAP = {
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
-NUM_ITERATIONS = 1000
-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
-return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+NUM_ITERATIONS = 100
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
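# Reference: int8 inputs multiplied in fp32, then scaled per-row (a) and
# per-column (b); bias handling is in the elided lines of this hunk.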
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
@@ -72,6 +71,7 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
return o.to(out_dtype)
def test(
handle,
device,
@@ -87,30 +87,38 @@
)
M, K = x_shape
N = w_shape[1]
-x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
-weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
-x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
-weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
-bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10
-ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias)
-x_packed = TestTensor(
-(M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
-)
-x_scale = TestTensor(
-(M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
-)
-weights = TestTensor(
-(K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
-)
-weights_scale = TestTensor(
-(N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
-)
-y = TestTensor(y_shape, None, dtype, device)
-bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)
+x_packed = TestTensor(
+(M, K),
+None,
+InfiniDtype.I8,
+device,
+mode="randint",
+randint_low=-128,
+randint_high=127,
+)
+weights = TestTensor(
+(K, N),
+None,
+InfiniDtype.I8,
+device,
+mode="randint",
+randint_low=-128,
+randint_high=127,
+)
+x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
+weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
+bias = TestTensor((N,), None, dtype, device, mode="random")
+y = TestTensor(y_shape, None, dtype, device, mode="zeros")
+ans = torch_scaled_mm(
+x_packed.torch_tensor(),
+weights.torch_tensor(),
+x_scale.torch_tensor(),
+weights_scale.torch_tensor(),
+out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
+bias=bias.torch_tensor(),
+)
descriptor = infiniopOperatorDescriptor_t()
check_error(
@@ -164,7 +172,20 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(
"PyTorch",
lambda: torch_scaled_mm(
x_packed.torch_tensor(),
weights.torch_tensor(),
x_scale.torch_tensor(),
weights_scale.torch_tensor(),
out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
bias=bias.torch_tensor()
),
device,
NUM_PRERUN,
NUM_ITERATIONS
)
profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
@@ -181,6 +202,12 @@ if __name__ == "__main__":
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
-test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+# muDNN(v3101): INT8 quantized multiplication → BF16 output.
+# Moore backend: BF16 output only.
+if args.moore == True:
+_TENSOR_DTYPES_MOORE = [InfiniDtype.BF16]
+test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES_MOORE)
+else:
+test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")