Commit 66a8eb93 authored by zhushuang's avatar zhushuang
Browse files

feat: Add BF16 support gemm in moore gpu and rename existing 'musa' to 'moore' in some files

parent 831021b8
......@@ -15,7 +15,7 @@
#include "ascend/ascend_handle.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/musa_handle.h"
#include "moore/moore_handle.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
......@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
......
#include "../../../utils.h"
#include "../pool.h"
#include "musa_handle.h"
#include "moore_handle.h"
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
......@@ -10,7 +10,7 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
namespace device::musa {
namespace device::moore {
class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
......@@ -39,4 +39,4 @@ public:
int gridSizeZ() const;
};
} // namespace device::musa
} // namespace device::moore
#include "common_musa.h"
#include "moore_common.h"
namespace device::musa {
namespace device::moore {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
......@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS;
}
} // namespace device::musa
} // namespace device::moore
#ifndef __INFINIOP_MUSA_HANDLE_H__
#define __INFINIOP_MUSA_HANDLE_H__
#ifndef __INFINIOP_MOORE_HANDLE_H__
#define __INFINIOP_MOORE_HANDLE_H__
#include "../../handle.h"
#include <memory>
namespace device::musa {
namespace device::moore {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
......@@ -20,6 +20,6 @@ private:
std::shared_ptr<Internal> _internal;
};
} // namespace device::musa
} // namespace device::moore
#endif // __INFINIOP_MUSA_HANDLE_H__
#endif // __INFINIOP_MOORE_HANDLE_H__
#define INFINIOP_MUSA_KERNEL __global__ void
#define INFINIOP_MOORE_KERNEL __global__ void
#include <musa_bf16.h>
#include <musa_fp16.h>
// Possible maximum number of threads per block for MUSA architectures
// Used for picking the correct kernel launch configuration
#define MUSA_BLOCK_SIZE_2048 2048
#define MUSA_BLOCK_SIZE_1024 1024
#define MUSA_BLOCK_SIZE_512 512
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
#define CHECK_MUSA(API) CHECK_INTERNAL(API, musaSuccess)
#define CHECK_MOORE(API) CHECK_INTERNAL(API, musaSuccess)
using musa_bfloat16 = mt_bfloat16;
using musa_bfloat162 = mt_bfloat162;
namespace device::musa {
namespace device::moore {
// return the memory offset into the original tensor, given the flattened index of the broadcasted tensor
__forceinline__ __device__ __host__ size_t
......@@ -45,7 +45,7 @@ indexToOffset(
}
return res;
}
} // namespace device::musa
} // namespace device::moore
__forceinline__ __device__ float
exp_(const float val) {
......
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__
#include "../gemm.h"
DESCRIPTOR(moore)
#endif // __GEMM_MOORE_H__
#include "../../../devices/musa/common_musa.h"
#include "../../../devices/musa/musa_handle.h"
#include "gemm_musa.h"
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "gemm_moore.h"
namespace op::gemm::musa {
namespace op::gemm::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
......@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata>
infiniStatus_t calculate(
const MatmulInfo &info,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) {
void *stream) const {
musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) {
alpha_ = __float2half(alpha);
beta_ = __float2half(beta);
// MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
// This ensures correct computation during the muBLAS GEMM operation.
// Declare half-precision variables to handle F16 types.
half alpha_h, beta_h;
    // Initialize generic void pointers for alpha and beta.
    // They initially point at the original float values and are used
    // directly when the GEMM operation is performed with F32 data.
const void *p_alpha = &alpha;
const void *p_beta = &beta;
switch (_dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha;
beta_ = beta;
// Convert alpha/beta to half-precision and update the pointers.
alpha_h = __float2half(alpha);
beta_h = __float2half(beta);
p_alpha = &alpha_h;
p_beta = &beta_h;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = MUSA_R_16BF;
compute_type = MUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (info.is_transed) {
if (_info.is_transed) {
std::swap(a, b);
}
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_a = _info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
CHECK_STATUS(_opaque->internal->useMublas(
(musaStream_t)stream,
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
......@@ -75,24 +97,24 @@ infiniStatus_t calculate(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha_,
static_cast<int>(_info.m),
static_cast<int>(_info.n),
static_cast<int>(_info.k),
p_alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
static_cast<int>(_info.a_matrix.ld()),
_info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta_,
static_cast<int>(_info.b_matrix.ld()),
_info.b_matrix.stride,
p_beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
static_cast<int>(_info.c_matrix.ld()),
_info.c_matrix.stride,
static_cast<int>(_info.batch),
compute_type,
MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
......@@ -100,22 +122,4 @@ infiniStatus_t calculate(
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
} // namespace op::gemm::moore
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR(musa)
#endif // __GEMM_MUSA_H__
......@@ -18,7 +18,7 @@
#include "metax/gemm_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/gemm_musa.h"
#include "moore/gemm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
......@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
......@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize(
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa);
GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm(
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa);
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
......
#ifndef __RMS_NORM_MUSA_H__
#define __RMS_NORM_MUSA_H__
#ifndef __RMS_NORM_MOORE_H__
#define __RMS_NORM_MOORE_H__
#include "../rms_norm.h"
DESCRIPTOR(musa)
DESCRIPTOR(moore)
#endif
#include "../../../devices/musa/common_musa.h"
#include "rms_norm_musa.h"
#include "../../../devices/moore/moore_common.h"
#include "rms_norm_moore.h"
#include "../../../devices/musa/musa_kernel_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
......@@ -9,7 +9,7 @@
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MUSA_KERNEL rmsnormKernel(
INFINIOP_MOORE_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
......@@ -20,10 +20,10 @@ INFINIOP_MUSA_KERNEL rmsnormKernel(
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::musa {
namespace op::rms_norm::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create(
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::musa::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
......@@ -109,15 +109,15 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::musa
} // namespace op::rms_norm::moore
......@@ -15,7 +15,7 @@
#include "metax/rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.h"
#include "moore/rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
......@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
}
......@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa);
GET(INFINI_DEVICE_MOORE, moore);
#endif
}
......@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa);
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
}
......@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, musa);
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
}
......
......@@ -6,7 +6,7 @@
#include "cuda/infinirt_cuda.cuh"
#include "kunlun/infinirt_kunlun.h"
#include "metax/infinirt_metax.h"
#include "musa/infinirt_musa.h"
#include "moore/infinirt_moore.h"
thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU;
thread_local int CURRENT_DEVICE_ID = 0;
......
#include "infinirt_musa.h"
#include "infinirt_moore.h"
#include "../../utils.h"
#include <musa_runtime.h>
#include <musa_runtime_api.h>
......
......@@ -119,7 +119,7 @@ option_end()
if has_config("moore-gpu") then
add_defines("ENABLE_MOORE_API")
includes("xmake/musa.lua")
includes("xmake/moore.lua")
end
-- 海光
......
......@@ -42,8 +42,8 @@ target("infiniop-moore")
set_languages("cxx17")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
add_files("../src/infiniop/devices/musa/*.cc")
add_files("../src/infiniop/ops/*/musa/*.mu", {rule = "mu"})
add_files("../src/infiniop/devices/moore/*.cc")
add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"})
target_end()
target("infinirt-moore")
......@@ -53,5 +53,5 @@ target("infinirt-moore")
add_deps("infini-utils")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC")
add_files("../src/infinirt/musa/*.cc")
add_files("../src/infinirt/moore/*.cc")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment