Commit d4f726de authored by zhushuang

issue/972 - feat: add scaled_mm with muDNN BatchMatMul for moore gpu

parent 012df56c
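For orientation: this commit lowers InfiniOp's scaled INT8 GEMM (scaled_mm) onto muDNN's BatchMatMul for Moore (MUSA) GPUs. INT8 operands are multiplied, the integer product is dequantized with a per-row scale for A and a per-column scale for B, an optional bias is added, and the result is written as FP16/BF16. A minimal CPU sketch of that contract follows; it is a hypothetical helper, not part of the commit, and the bias placement (added after dequantization) is assumed to match torch_scaled_mm in the test further down.

#include <cstdint>
#include <vector>

// Hypothetical reference for the scaled INT8 GEMM wired up in this commit:
//   out[m][n] = (sum_k A[m][k] * B[k][n]) * a_scale[m] * b_scale[n] + bias[n]
std::vector<float> scaled_i8_gemm_ref(const std::vector<int8_t> &A,      // M x K, row-major
                                      const std::vector<int8_t> &B,      // K x N, row-major
                                      const std::vector<float> &a_scale, // length M
                                      const std::vector<float> &b_scale, // length N
                                      const std::vector<float> &bias,    // length N, may be empty
                                      int M, int K, int N) {
    std::vector<float> out(static_cast<size_t>(M) * N);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            int32_t acc = 0; // accumulate INT8 products in int32
            for (int k = 0; k < K; ++k) {
                acc += static_cast<int32_t>(A[static_cast<size_t>(m) * K + k]) * B[static_cast<size_t>(k) * N + n];
            }
            float v = static_cast<float>(acc) * a_scale[m] * b_scale[n];
            if (!bias.empty()) {
                v += bias[n]; // bias after dequantization (assumed convention)
            }
            out[static_cast<size_t>(m) * N + n] = v;
        }
    }
    return out;
}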
-#ifndef __GEMM_INFO_H__
+#ifndef __I8GEMM_INFO_H__
 #define __I8GEMM_INFO_H__
 #include "../../../utils.h"
...
@@ -18,8 +18,8 @@
         size_t workspace_size, \
         infiniDtype_t out_dtype, \
         infiniDevice_t device_type, int device_id) \
-        : InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
-          _opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
+        : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \
+          _workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \
 \
 public: \
     ~Descriptor(); \
...
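The only change in this hunk is the order of the mem-initializer list, presumably so it matches the members' declaration order: C++ always initializes members in declaration order regardless of how the list is written, and a mismatched list draws -Wreorder and can silently read a not-yet-initialized member. A minimal illustration (hypothetical code, not from this repo):

struct Example {
    int a;
    int b;
    // Fields initialize in declaration order (a, then b). Listing b(a) before
    // a(1) would not change that order, but it would trigger -Wreorder;
    // keeping source order and execution order aligned avoids surprises.
    Example() : a(1), b(a) {}
};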
#ifndef __INT8_GEMM_MOORE_API_H__
#define __INT8_GEMM_MOORE_API_H__

#include "../int8_gemm.h"

DESCRIPTOR(moore)

#endif // __INT8_GEMM_MOORE_API_H__
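For readers without ../int8_gemm.h at hand, here is a plausible expansion of DESCRIPTOR(moore), reconstructed from the macro fragment in the hunk above. Member names and their order are inferred from the constructor's initializer list; treat this as a sketch, not the real macro.

namespace op::i8gemm::moore {
class Descriptor final : public InfiniopDescriptor {
    struct Opaque;
    Opaque *_opaque;
    size_t _workspace_size;
    I8GemmInfo _info;
    infiniDtype_t _out_dtype;

    Descriptor(Opaque *opaque, I8GemmInfo info,
               size_t workspace_size, infiniDtype_t out_dtype,
               infiniDevice_t device_type, int device_id)
        : InfiniopDescriptor{device_type, device_id}, _opaque(opaque),
          _workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {}

public:
    ~Descriptor();
    // create(...) and calculate(...) as used by the implementation file below
};
} // namespace op::i8gemm::moore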
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "int8_gemm_moore.h"
namespace op::i8gemm::moore {
static void moore_i8gemm_launch(
    const I8GemmInfo &info,
    std::shared_ptr<device::moore::Handle::Internal> &internal,
    void *out,
    const int8_t *A,
    const int8_t *B,
    const float *A_scale,
    const float *B_scale,
    const void *bias,
    infiniDtype_t out_dtype,
    musaStream_t stream) {
    internal->useMudnn(stream,
        [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
        // 1. Operator
        auto matmul = std::make_unique<::musa::dnn::BatchMatMul>();
        matmul->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);

        // 2. Tensors
        ::musa::dnn::Tensor out_t, a_t, b_t, bias_t;
        ::musa::dnn::Tensor scale_a_t, scale_b_t;

        // 3. Output dtype (F16 or BF16, checked in Descriptor::create)
        if (out_dtype == INFINI_DTYPE_F16) {
            out_t.SetType(::musa::dnn::Tensor::Type::HALF);
            bias_t.SetType(::musa::dnn::Tensor::Type::HALF);
        } else {
            out_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
            bias_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
        }

        // 4. INT8 inputs
        a_t.SetType(::musa::dnn::Tensor::Type::INT8);
        b_t.SetType(::musa::dnn::Tensor::Type::INT8);

        // 5. Scales (FLOAT; per-channel layout is set in step 10)
        scale_a_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
        scale_b_t.SetType(::musa::dnn::Tensor::Type::FLOAT);

        // 6. Bind memory
        out_t.SetAddr(out);
        a_t.SetAddr(const_cast<int8_t *>(A));
        b_t.SetAddr(const_cast<int8_t *>(B));
        scale_a_t.SetAddr(const_cast<float *>(A_scale));
        scale_b_t.SetAddr(const_cast<float *>(B_scale));
        if (bias) {
            bias_t.SetAddr(const_cast<void *>(bias));
        }

        // 7. A NdInfo
        {
            std::array<int64_t, 3> dims;
            std::array<int64_t, 3> strides;
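            // col_stride != 1 means A is stored column-major; describe it to
            // muDNN as a row-major (k, m) view and flag it transposed via
            // SetTranspose below.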
            if (info.a_matrix.col_stride != 1) {
                dims = {info.batch, info.k, info.m};
            } else {
                dims = {info.batch, info.m, info.k};
            }
            strides = {info.a_matrix.stride, info.a_matrix.ld(), 1};
            a_t.SetNdInfo(3, dims.data(), strides.data());
        }

        // 8. B NdInfo
        {
            std::array<int64_t, 3> dims;
            std::array<int64_t, 3> strides;
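            // Same convention for B: a column-major matrix is presented as its
            // row-major transpose and resolved by SetTranspose.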
            if (info.b_matrix.col_stride != 1) {
                dims = {info.batch, info.n, info.k};
            } else {
                dims = {info.batch, info.k, info.n};
            }
            strides = {info.b_matrix.stride, info.b_matrix.ld(), 1};
            b_t.SetNdInfo(3, dims.data(), strides.data());
        }

        // 9. Output NdInfo (dense row-major)
        {
            std::array<int64_t, 3> dims = {info.batch, info.m, info.n};
            std::array<int64_t, 3> strides = {info.m * info.n, info.n, 1};
            out_t.SetNdInfo(3, dims.data(), strides.data());
        }

        // 10. Bias & scale NdInfo
        if (bias) {
            std::array<int64_t, 1> dims = {info.n};
            std::array<int64_t, 1> strides = {1};
            bias_t.SetNdInfo(1, dims.data(), strides.data());
        }
        {
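            // Scales are per-channel: A's scale holds one value per output row
            // (M x 1), B's scale one value per output column (1 x N).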
            std::array<int64_t, 3> a_scale_dims = {info.batch, info.m, 1};
            std::array<int64_t, 3> a_scale_strides = {info.m, 1, 1};
            scale_a_t.SetNdInfo(3, a_scale_dims.data(), a_scale_strides.data());

            std::array<int64_t, 3> b_scale_dims = {info.batch, 1, info.n};
            std::array<int64_t, 3> b_scale_strides = {info.n, 1, 1};
            scale_b_t.SetNdInfo(3, b_scale_dims.data(), b_scale_strides.data());
        }

        // 11. Transpose
        matmul->SetTranspose(
            info.a_matrix.col_stride != 1,
            info.b_matrix.col_stride != 1);

        // 12. Lt param (no epilogue enum)
        ::musa::dnn::MatMulLtParam lt_param;
        lt_param.SetScale(
            scale_a_t,
            scale_b_t,
            ::musa::dnn::Tensor(),
            ::musa::dnn::Tensor());

        // 13. Alpha / Beta
        matmul->SetAlpha(1.0);
        matmul->SetBeta(0.0);
        matmul->SetGamma(1.0);

        // 14. Workspace
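        // Allocation callback handed to RunLt: muDNN requests scratch space
        // through it, and the MemoryHandler's deleter frees the buffer.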
        ::musa::dnn::MemoryMaintainer maintainer = [](size_t size) {
            void *ptr = nullptr;
            musaMalloc(&ptr, size);
            return ::musa::dnn::MemoryHandler(
                ptr, [](void *p) { if (p) musaFree(p); });
        };

        // 15. Run
        matmul->RunLt(
            mudnn_handle,
            out_t,
            a_t,
            b_t,
            ::musa::dnn::Tensor(),
            bias ? bias_t : ::musa::dnn::Tensor(),
            lt_param,
            maintainer);
        return INFINI_STATUS_SUCCESS;
    });
}
/* ================= Descriptor ================= */

struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scale_desc) {
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);

    auto result = I8GemmInfo::create(out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
    CHECK_RESULT(result);

    *desc_ptr = new Descriptor(
        new Opaque{handle->internal()},
        result.take(),
        0,
        dtype,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *bias,
    const void *a,
    const void *a_scale,
    const void *b,
    const void *b_scale,
    void *stream_) const {
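    // workspace / workspace_size are unused on this backend: muDNN allocates
    // its own scratch through the MemoryMaintainer above, and create() reports
    // a workspace size of 0.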
    moore_i8gemm_launch(
        _info,
        _opaque->internal,
        out,
        static_cast<const int8_t *>(a),
        static_cast<const int8_t *>(b),
        static_cast<const float *>(a_scale),
        static_cast<const float *>(b_scale),
        bias,
        _out_dtype,
        reinterpret_cast<musaStream_t>(stream_));
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::i8gemm::moore
@@ -6,6 +6,10 @@
 #include "nvidia/int8_gemm_nvidia.cuh"
 #endif
+#if defined(ENABLE_MOORE_API)
+#include "moore/int8_gemm_moore.h"
+#endif
+
 __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
                                                   infiniopI8GemmDescriptor_t *desc_ptr,
                                                   infiniopTensorDescriptor_t out_desc,
@@ -31,6 +35,9 @@ __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
 #endif
 #if defined(ENABLE_QY_API)
 CREATE(INFINI_DEVICE_QY, nvidia)
+#endif
+#if defined(ENABLE_MOORE_API)
+CREATE(INFINI_DEVICE_MOORE, moore)
 #endif
 default:
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -49,6 +56,9 @@ __C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t desc,
 #endif
 #if defined(ENABLE_QY_API)
 GET(INFINI_DEVICE_QY, nvidia)
+#endif
+#if defined(ENABLE_MOORE_API)
+GET(INFINI_DEVICE_MOORE, moore)
 #endif
 default:
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -76,6 +86,9 @@ __C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
 #endif
 #if defined(ENABLE_QY_API)
 CACULATE(INFINI_DEVICE_QY, nvidia)
+#endif
+#if defined(ENABLE_MOORE_API)
+CACULATE(INFINI_DEVICE_MOORE, moore)
 #endif
 default:
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -94,6 +107,9 @@ __C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t desc,
 #endif
 #if defined(ENABLE_QY_API)
 DESTROY(INFINI_DEVICE_QY, nvidia)
+#endif
+#if defined(ENABLE_MOORE_API)
+DESTROY(INFINI_DEVICE_MOORE, moore)
 #endif
 default:
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
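A hedged sketch of driving these four entry points in order. The public C signatures are not shown in this diff, so the argument order below simply mirrors the internal Descriptor::create/calculate signatures and should be treated as an assumption.

// Hypothetical usage sketch; consult the project's public infiniop header for
// the real signatures. Error handling is elided to keep the flow visible.
infiniStatus_t run_scaled_mm(infiniopHandle_t handle,
                             infiniopTensorDescriptor_t out_desc,
                             infiniopTensorDescriptor_t bias_desc,
                             infiniopTensorDescriptor_t a_desc,
                             infiniopTensorDescriptor_t a_scale_desc,
                             infiniopTensorDescriptor_t b_desc,
                             infiniopTensorDescriptor_t b_scale_desc,
                             void *out, const void *bias,
                             const void *a, const void *a_scale,
                             const void *b, const void *b_scale,
                             void *stream) {
    infiniopI8GemmDescriptor_t desc;
    infiniopCreateI8GemmDescriptor(handle, &desc, out_desc, bias_desc,
                                   a_desc, a_scale_desc, b_desc, b_scale_desc);
    size_t workspace_size = 0;
    infiniopGetI8GemmWorkspaceSize(desc, &workspace_size);
    // The Moore path reports 0 here and allocates scratch internally.
    infiniopI8Gemm(desc, /*workspace=*/nullptr, workspace_size, out, bias,
                   a, a_scale, b, b_scale, stream);
    return infiniopDestroyI8GemmDescriptor(desc);
}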
@@ -25,6 +25,7 @@ from enum import Enum, auto
 # These are not meant to be imported from other modules
 _TEST_CASES_ = [
     # x_shape, w_shape, y_shape, alpha, beta
+    ((2, 4), (4, 2), (2, 2)),
     ((128, 512), (512, 1024), (128, 1024)),
     ((256, 1024), (1024, 2048), (256, 2048)),
     ((1024, 2048), (2048, 1024), (1024, 1024)),
@@ -59,10 +60,8 @@ _TOLERANCE_MAP = {
 DEBUG = False
 PROFILE = False
 NUM_PRERUN = 10
-NUM_ITERATIONS = 1000
+NUM_ITERATIONS = 100
 
-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
@@ -72,6 +71,7 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
     return o.to(out_dtype)
+
 def test(
     handle,
     device,
@@ -87,30 +87,38 @@ def test(
     )
     M, K = x_shape
     N = w_shape[1]
-    x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
-    weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
-    x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
-    weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
-    bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10
-    ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias)
-    x_packed = TestTensor(
-        (M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
-    )
-    x_scale = TestTensor(
-        (M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
-    )
-    weights = TestTensor(
-        (K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
-    )
-    weights_scale = TestTensor(
-        (N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
-    )
-    y = TestTensor(y_shape, None, dtype, device)
-    bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)
+    x_packed = TestTensor(
+        (M, K),
+        None,
+        InfiniDtype.I8,
+        device,
+        mode="randint",
+        randint_low=-128,
+        randint_high=127,
+    )
+    weights = TestTensor(
+        (K, N),
+        None,
+        InfiniDtype.I8,
+        device,
+        mode="randint",
+        randint_low=-128,
+        randint_high=127,
+    )
+    x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
+    weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
+    bias = TestTensor((N,), None, dtype, device, mode="random")
+    y = TestTensor(y_shape, None, dtype, device, mode="zeros")
+    ans = torch_scaled_mm(
+        x_packed.torch_tensor(),
+        weights.torch_tensor(),
+        x_scale.torch_tensor(),
+        weights_scale.torch_tensor(),
+        out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
+        bias=bias.torch_tensor(),
+    )
     descriptor = infiniopOperatorDescriptor_t()
     check_error(
@@ -164,7 +172,20 @@ def test(
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(
+            "PyTorch",
+            lambda: torch_scaled_mm(
+                x_packed.torch_tensor(),
+                weights.torch_tensor(),
+                x_scale.torch_tensor(),
+                weights_scale.torch_tensor(),
+                out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
+                bias=bias.torch_tensor()
+            ),
+            device,
+            NUM_PRERUN,
+            NUM_ITERATIONS
+        )
         profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
@@ -181,6 +202,12 @@ if __name__ == "__main__":
     NUM_ITERATIONS = args.num_iterations
     for device in get_test_devices(args):
-        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+        # muDNN (v3101) INT8 quantized matmul on the Moore backend supports
+        # BF16 output only.
+        if args.moore:
+            _TENSOR_DTYPES_MOORE = [InfiniDtype.BF16]
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES_MOORE)
+        else:
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
     print("\033[92mTest passed!\033[0m")