Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
66a8eb93
Commit
66a8eb93
authored
Aug 18, 2025
by
zhushuang
Browse files
feat: Add BF16 support gemm in moore gpu and rename existing 'musa' to 'moore' in some files
parent
831021b8
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
192 additions
and
59 deletions
+192
-59
src/infiniop/devices/handle.cc
src/infiniop/devices/handle.cc
+3
-3
src/infiniop/devices/moore/moore_common.h
src/infiniop/devices/moore/moore_common.h
+3
-3
src/infiniop/devices/moore/moore_handle.cc
src/infiniop/devices/moore/moore_handle.cc
+3
-3
src/infiniop/devices/moore/moore_handle.h
src/infiniop/devices/moore/moore_handle.h
+5
-5
src/infiniop/devices/moore/moore_kernel_common.h
src/infiniop/devices/moore/moore_kernel_common.h
+7
-7
src/infiniop/ops/gemm/moore/gemm_moore.h
src/infiniop/ops/gemm/moore/gemm_moore.h
+8
-0
src/infiniop/ops/gemm/moore/gemm_moore.mu
src/infiniop/ops/gemm/moore/gemm_moore.mu
+125
-0
src/infiniop/ops/gemm/musa/gemm_musa.h
src/infiniop/ops/gemm/musa/gemm_musa.h
+0
-8
src/infiniop/ops/gemm/operator.cc
src/infiniop/ops/gemm/operator.cc
+5
-5
src/infiniop/ops/rms_norm/moore/rms_norm_moore.h
src/infiniop/ops/rms_norm/moore/rms_norm_moore.h
+8
-0
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
+14
-14
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+5
-5
src/infinirt/infinirt.cc
src/infinirt/infinirt.cc
+1
-1
src/infinirt/moore/infinirt_moore.cc
src/infinirt/moore/infinirt_moore.cc
+1
-1
src/infinirt/moore/infinirt_moore.h
src/infinirt/moore/infinirt_moore.h
+0
-0
xmake.lua
xmake.lua
+1
-1
xmake/moore.lua
xmake/moore.lua
+3
-3
No files found.
src/infiniop/devices/handle.cc
View file @
66a8eb93
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "ascend/ascend_handle.h"
#include "ascend/ascend_handle.h"
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
#include "m
usa/musa
_handle.h"
#include "m
oore/moore
_handle.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
#include "kunlun/kunlun_handle.h"
...
@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
...
@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CREATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
...
@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
m
usa
);
DELETE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
...
src/infiniop/devices/m
usa/
common
_musa
.h
→
src/infiniop/devices/m
oore/moore_
common.h
View file @
66a8eb93
#include "../../../utils.h"
#include "../../../utils.h"
#include "../pool.h"
#include "../pool.h"
#include "m
usa
_handle.h"
#include "m
oore
_handle.h"
#include <mublas.h>
#include <mublas.h>
#include <mudnn.h>
#include <mudnn.h>
#include <musa.h>
#include <musa.h>
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
namespace
device
::
m
usa
{
namespace
device
::
m
oore
{
class
Handle
::
Internal
{
class
Handle
::
Internal
{
Pool
<
std
::
unique_ptr
<
mublasHandle_t
>>
mublas_handles
;
Pool
<
std
::
unique_ptr
<
mublasHandle_t
>>
mublas_handles
;
...
@@ -39,4 +39,4 @@ public:
...
@@ -39,4 +39,4 @@ public:
int
gridSizeZ
()
const
;
int
gridSizeZ
()
const
;
};
};
}
// namespace device::m
usa
}
// namespace device::m
oore
src/infiniop/devices/m
usa/musa
_handle.cc
→
src/infiniop/devices/m
oore/moore
_handle.cc
View file @
66a8eb93
#include "common
_musa
.h"
#include "
moore_
common.h"
namespace
device
::
m
usa
{
namespace
device
::
m
oore
{
Handle
::
Handle
(
infiniDevice_t
device
,
int
device_id
)
Handle
::
Handle
(
infiniDevice_t
device
,
int
device_id
)
:
InfiniopHandle
{
device
,
device_id
},
:
InfiniopHandle
{
device
,
device_id
},
_internal
(
std
::
make_shared
<
Handle
::
Internal
>
(
device_id
))
{}
_internal
(
std
::
make_shared
<
Handle
::
Internal
>
(
device_id
))
{}
...
@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
...
@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
}
// namespace device::m
usa
}
// namespace device::m
oore
src/infiniop/devices/m
usa/musa
_handle.h
→
src/infiniop/devices/m
oore/moore
_handle.h
View file @
66a8eb93
#ifndef __INFINIOP_M
USA
_HANDLE_H__
#ifndef __INFINIOP_M
OORE
_HANDLE_H__
#define __INFINIOP_M
USA
_HANDLE_H__
#define __INFINIOP_M
OORE
_HANDLE_H__
#include "../../handle.h"
#include "../../handle.h"
#include <memory>
#include <memory>
namespace
device
::
m
usa
{
namespace
device
::
m
oore
{
struct
Handle
:
public
InfiniopHandle
{
struct
Handle
:
public
InfiniopHandle
{
Handle
(
int
device_id
);
Handle
(
int
device_id
);
class
Internal
;
class
Internal
;
...
@@ -20,6 +20,6 @@ private:
...
@@ -20,6 +20,6 @@ private:
std
::
shared_ptr
<
Internal
>
_internal
;
std
::
shared_ptr
<
Internal
>
_internal
;
};
};
}
// namespace device::m
usa
}
// namespace device::m
oore
#endif // __INFINIOP_M
USA
_HANDLE_H__
#endif // __INFINIOP_M
OORE
_HANDLE_H__
src/infiniop/devices/m
usa/musa
_kernel_common.h
→
src/infiniop/devices/m
oore/moore
_kernel_common.h
View file @
66a8eb93
#define INFINIOP_M
USA
_KERNEL __global__ void
#define INFINIOP_M
OORE
_KERNEL __global__ void
#include <musa_bf16.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#include <musa_fp16.h>
// Possible maximum number of threads per block for MUSA architectures
// Possible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
// Used for picking correct kernel launch configuration
#define M
USA
_BLOCK_SIZE_2048 2048
#define M
OORE
_BLOCK_SIZE_2048 2048
#define M
USA
_BLOCK_SIZE_1024 1024
#define M
OORE
_BLOCK_SIZE_1024 1024
#define M
USA
_BLOCK_SIZE_512 512
#define M
OORE
_BLOCK_SIZE_512 512
#define CHECK_M
USA
(API) CHECK_INTERNAL(API, musaSuccess)
#define CHECK_M
OORE
(API) CHECK_INTERNAL(API, musaSuccess)
using
musa_bfloat16
=
mt_bfloat16
;
using
musa_bfloat16
=
mt_bfloat16
;
using
musa_bfloat162
=
mt_bfloat162
;
using
musa_bfloat162
=
mt_bfloat162
;
namespace
device
::
m
usa
{
namespace
device
::
m
oore
{
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__
__device__
__host__
size_t
__forceinline__
__device__
__host__
size_t
...
@@ -45,7 +45,7 @@ indexToOffset(
...
@@ -45,7 +45,7 @@ indexToOffset(
}
}
return
res
;
return
res
;
}
}
}
// namespace device::m
usa
}
// namespace device::m
oore
__forceinline__
__device__
float
__forceinline__
__device__
float
exp_
(
const
float
val
)
{
exp_
(
const
float
val
)
{
...
...
src/infiniop/ops/gemm/moore/gemm_moore.h
0 → 100644
View file @
66a8eb93
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__
#include "../gemm.h"
DESCRIPTOR
(
moore
)
#endif // __GEMM_MOORE_H__
src/infiniop/ops/gemm/m
usa
/gemm_m
usa
.mu
→
src/infiniop/ops/gemm/m
oore
/gemm_m
oore
.mu
View file @
66a8eb93
#include "../../../devices/m
usa/
common
_musa
.h"
#include "../../../devices/m
oore/moore_
common.h"
#include "../../../devices/m
usa/musa
_handle.h"
#include "../../../devices/m
oore/moore
_handle.h"
#include "gemm_m
usa
.h"
#include "gemm_m
oore
.h"
namespace op::gemm::m
usa
{
namespace op::gemm::m
oore
{
struct Descriptor::Opaque {
struct Descriptor::Opaque {
std::shared_ptr<device::m
usa
::Handle::Internal> internal;
std::shared_ptr<device::m
oore
::Handle::Internal> internal;
};
};
Descriptor::~Descriptor() {
Descriptor::~Descriptor() {
...
@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create(
...
@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::m
usa
::Handle *>(handle_);
auto handle = reinterpret_cast<device::m
oore
::Handle *>(handle_);
auto dtype = c_desc->dtype();
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32
, INFINI_DTYPE_BF16
);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
CHECK_RESULT(result);
...
@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create(
...
@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
template <typename Tdata>
infiniStatus_t Descriptor::calculate(
infiniStatus_t calculate(
void *workspace,
const MatmulInfo &info,
size_t workspace_size,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
void *c,
void *c,
float beta,
float beta,
const void *a,
const void *a,
const void *b,
const void *b,
float alpha,
float alpha,
void *stream) {
void *stream)
const
{
musaDataType a_type, b_type, c_type;
musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type;
mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) {
// MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
alpha_ = __float2half(alpha);
// This ensures correct computation during the muBLAS GEMM operation.
beta_ = __float2half(beta);
// Declare half-precision variables to handle F16 types.
half alpha_h, beta_h;
// Initialize generic void pointers for alpha and beta.
// They point to the original float values
// It will be used directly when the GEMM operation is performed with F32 data.
const void *p_alpha = &alpha;
const void *p_beta = &beta;
switch (_dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = MUSA_R_16F;
a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F;
compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha;
// Convert alpha/beta to half-precision and update the pointers.
beta_ = beta;
alpha_h = __float2half(alpha);
beta_h = __float2half(beta);
p_alpha = &alpha_h;
p_beta = &beta_h;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = MUSA_R_16BF;
compute_type = MUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = MUSA_R_32F;
a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
if (info.is_transed) {
if (
_
info.is_transed) {
std::swap(a, b);
std::swap(a, b);
}
}
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_a =
_
info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b =
_
info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
CHECK_STATUS(_
opaque->
internal->useMublas(
(musaStream_t)stream,
(musaStream_t)stream,
[&](mublasHandle_t handle) {
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
CHECK_MUBLAS(
...
@@ -75,24 +97,24 @@ infiniStatus_t calculate(
...
@@ -75,24 +97,24 @@ infiniStatus_t calculate(
handle,
handle,
op_a,
op_a,
op_b,
op_b,
static_cast<int>(info.m),
static_cast<int>(
_
info.m),
static_cast<int>(info.n),
static_cast<int>(
_
info.n),
static_cast<int>(info.k),
static_cast<int>(
_
info.k),
&
alpha
_
,
p_
alpha,
a,
a,
a_type,
a_type,
static_cast<int>(info.a_matrix.ld()),
static_cast<int>(
_
info.a_matrix.ld()),
info.a_matrix.stride,
_
info.a_matrix.stride,
b,
b,
b_type,
b_type,
static_cast<int>(info.b_matrix.ld()),
static_cast<int>(
_
info.b_matrix.ld()),
info.b_matrix.stride,
_
info.b_matrix.stride,
&
beta
_
,
p_
beta,
c,
c,
c_type,
c_type,
static_cast<int>(info.c_matrix.ld()),
static_cast<int>(
_
info.c_matrix.ld()),
info.c_matrix.stride,
_
info.c_matrix.stride,
static_cast<int>(info.batch),
static_cast<int>(
_
info.batch),
compute_type,
compute_type,
MUBLAS_GEMM_DEFAULT));
MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
...
@@ -100,22 +122,4 @@ infiniStatus_t calculate(
...
@@ -100,22 +122,4 @@ infiniStatus_t calculate(
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
infiniStatus_t Descriptor::calculate(void *workspace,
} // namespace op::gemm::moore
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
src/infiniop/ops/gemm/musa/gemm_musa.h
deleted
100644 → 0
View file @
831021b8
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR
(
musa
)
#endif // __GEMM_MUSA_H__
src/infiniop/ops/gemm/operator.cc
View file @
66a8eb93
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include "metax/gemm_metax.h"
#include "metax/gemm_metax.h"
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
#include "m
usa
/gemm_m
usa
.h"
#include "m
oore
/gemm_m
oore
.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
#include "kunlun/gemm_kunlun.h"
...
@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
...
@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CREATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
...
@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize(
...
@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize(
GET
(
INFINI_DEVICE_METAX
,
metax
);
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
m
usa
);
GET
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm(
...
@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm(
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
...
@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
m
usa
);
DELETE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
...
src/infiniop/ops/rms_norm/m
usa
/rms_norm_m
usa
.h
→
src/infiniop/ops/rms_norm/m
oore
/rms_norm_m
oore
.h
View file @
66a8eb93
#ifndef __RMS_NORM_M
USA
_H__
#ifndef __RMS_NORM_M
OORE
_H__
#define __RMS_NORM_M
USA
_H__
#define __RMS_NORM_M
OORE
_H__
#include "../rms_norm.h"
#include "../rms_norm.h"
DESCRIPTOR
(
m
usa
)
DESCRIPTOR
(
m
oore
)
#endif
#endif
src/infiniop/ops/rms_norm/m
usa
/rms_norm_m
usa
.mu
→
src/infiniop/ops/rms_norm/m
oore
/rms_norm_m
oore
.mu
View file @
66a8eb93
#include "../../../devices/m
usa/
common
_musa
.h"
#include "../../../devices/m
oore/moore_
common.h"
#include "rms_norm_m
usa
.h"
#include "rms_norm_m
oore
.h"
#include "../../../devices/m
usa/musa
_kernel_common.h"
#include "../../../devices/m
oore/moore
_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../../../reduce/cuda/reduce.cuh"
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "../cuda/kernel.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_M
USA
_KERNEL rmsnormKernel(
INFINIOP_M
OORE
_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ y,
ptrdiff_t stride_y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
const Tdata *__restrict__ x,
...
@@ -20,10 +20,10 @@ INFINIOP_MUSA_KERNEL rmsnormKernel(
...
@@ -20,10 +20,10 @@ INFINIOP_MUSA_KERNEL rmsnormKernel(
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
}
namespace op::rms_norm::m
usa
{
namespace op::rms_norm::m
oore
{
struct Descriptor::Opaque {
struct Descriptor::Opaque {
std::shared_ptr<device::m
usa
::Handle::Internal> internal;
std::shared_ptr<device::m
oore
::Handle::Internal> internal;
};
};
Descriptor::~Descriptor() {
Descriptor::~Descriptor() {
...
@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create(
...
@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create(
}
}
*desc_ptr = new Descriptor(
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::m
usa
::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::m
oore
::Handle *>(handle)->internal()},
std::move(info),
std::move(info),
0,
0,
handle->device, handle->device_id);
handle->device, handle->device_id);
...
@@ -109,15 +109,15 @@ infiniStatus_t Descriptor::calculate(
...
@@ -109,15 +109,15 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == M
USA
_BLOCK_SIZE_1024) {
if (_opaque->internal->maxThreadsPerBlock() == M
OORE
_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<M
USA
_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<M
OORE
_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == M
USA
_BLOCK_SIZE_512) {
} else if (_opaque->internal->maxThreadsPerBlock() == M
OORE
_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<M
USA
_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<M
OORE
_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == M
USA
_BLOCK_SIZE_2048) {
} else if (_opaque->internal->maxThreadsPerBlock() == M
OORE
_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<M
USA
_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<M
OORE
_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
} // namespace op::rms_norm::m
usa
} // namespace op::rms_norm::m
oore
src/infiniop/ops/rms_norm/operator.cc
View file @
66a8eb93
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "metax/rms_norm_metax.cuh"
#include "metax/rms_norm_metax.cuh"
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
#include "m
usa
/rms_norm_m
usa
.h"
#include "m
oore
/rms_norm_m
oore
.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#include "kunlun/rms_norm_kunlun.h"
...
@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
...
@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CREATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
}
}
...
@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
...
@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
GET
(
INFINI_DEVICE_METAX
,
metax
);
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
m
usa
);
GET
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
}
}
...
@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
...
@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
}
}
...
@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
...
@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
m
usa
);
DESTROY
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#endif
}
}
...
...
src/infinirt/infinirt.cc
View file @
66a8eb93
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "cuda/infinirt_cuda.cuh"
#include "cuda/infinirt_cuda.cuh"
#include "kunlun/infinirt_kunlun.h"
#include "kunlun/infinirt_kunlun.h"
#include "metax/infinirt_metax.h"
#include "metax/infinirt_metax.h"
#include "m
usa
/infinirt_m
usa
.h"
#include "m
oore
/infinirt_m
oore
.h"
thread_local
infiniDevice_t
CURRENT_DEVICE_TYPE
=
INFINI_DEVICE_CPU
;
thread_local
infiniDevice_t
CURRENT_DEVICE_TYPE
=
INFINI_DEVICE_CPU
;
thread_local
int
CURRENT_DEVICE_ID
=
0
;
thread_local
int
CURRENT_DEVICE_ID
=
0
;
...
...
src/infinirt/m
usa
/infinirt_m
usa
.cc
→
src/infinirt/m
oore
/infinirt_m
oore
.cc
View file @
66a8eb93
#include "infinirt_m
usa
.h"
#include "infinirt_m
oore
.h"
#include "../../utils.h"
#include "../../utils.h"
#include <musa_runtime.h>
#include <musa_runtime.h>
#include <musa_runtime_api.h>
#include <musa_runtime_api.h>
...
...
src/infinirt/m
usa
/infinirt_m
usa
.h
→
src/infinirt/m
oore
/infinirt_m
oore
.h
View file @
66a8eb93
File moved
xmake.lua
View file @
66a8eb93
...
@@ -119,7 +119,7 @@ option_end()
...
@@ -119,7 +119,7 @@ option_end()
if
has_config
(
"moore-gpu"
)
then
if
has_config
(
"moore-gpu"
)
then
add_defines
(
"ENABLE_MOORE_API"
)
add_defines
(
"ENABLE_MOORE_API"
)
includes
(
"xmake/m
usa
.lua"
)
includes
(
"xmake/m
oore
.lua"
)
end
end
-- 海光
-- 海光
...
...
xmake/m
usa
.lua
→
xmake/m
oore
.lua
View file @
66a8eb93
...
@@ -42,8 +42,8 @@ target("infiniop-moore")
...
@@ -42,8 +42,8 @@ target("infiniop-moore")
set_languages
(
"cxx17"
)
set_languages
(
"cxx17"
)
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
add_cxflags
(
"-lstdc++"
,
"-fPIC"
,
"-Wno-comment"
)
add_cxflags
(
"-lstdc++"
,
"-fPIC"
,
"-Wno-comment"
)
add_files
(
"../src/infiniop/devices/m
usa
/*.cc"
)
add_files
(
"../src/infiniop/devices/m
oore
/*.cc"
)
add_files
(
"../src/infiniop/ops/*/m
usa
/*.mu"
,
{
rule
=
"mu"
})
add_files
(
"../src/infiniop/ops/*/m
oore
/*.mu"
,
{
rule
=
"mu"
})
target_end
()
target_end
()
target
(
"infinirt-moore"
)
target
(
"infinirt-moore"
)
...
@@ -53,5 +53,5 @@ target("infinirt-moore")
...
@@ -53,5 +53,5 @@ target("infinirt-moore")
add_deps
(
"infini-utils"
)
add_deps
(
"infini-utils"
)
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
add_cxflags
(
"-lstdc++"
,
"-fPIC"
)
add_cxflags
(
"-lstdc++"
,
"-fPIC"
)
add_files
(
"../src/infinirt/m
usa
/*.cc"
)
add_files
(
"../src/infinirt/m
oore
/*.cc"
)
target_end
()
target_end
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment