Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
e4605f7c
Unverified
Commit
e4605f7c
authored
Jul 11, 2025
by
PanZezhong1725
Committed by
GitHub
Jul 11, 2025
Browse files
Merge pull request #293 from YdrMaster/distinct-cuda
issue291 合并 cuda 代码
parents
5025ebed
eac2b0ca
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
82 additions
and
66 deletions
+82
-66
src/infiniop/ops/gemm/metax/gemm_metax.cc
src/infiniop/ops/gemm/metax/gemm_metax.cc
+8
-7
src/infiniop/ops/gemm/metax/gemm_metax.h
src/infiniop/ops/gemm/metax/gemm_metax.h
+1
-1
src/infiniop/ops/gemm/operator.cc
src/infiniop/ops/gemm/operator.cc
+5
-5
src/infiniop/ops/mul/cpu/mul_cpu.h
src/infiniop/ops/mul/cpu/mul_cpu.h
+1
-1
src/infiniop/ops/mul/cuda/kernel.cuh
src/infiniop/ops/mul/cuda/kernel.cuh
+0
-0
src/infiniop/ops/mul/nvidia/mul_nvidia.cu
src/infiniop/ops/mul/nvidia/mul_nvidia.cu
+8
-8
src/infiniop/ops/mul/nvidia/mul_nvidia.cuh
src/infiniop/ops/mul/nvidia/mul_nvidia.cuh
+1
-1
src/infiniop/ops/mul/operator.cc
src/infiniop/ops/mul/operator.cc
+7
-7
src/infiniop/ops/random_sample/metax/random_sample_kernel.h
src/infiniop/ops/random_sample/metax/random_sample_kernel.h
+2
-2
src/infiniop/ops/random_sample/metax/random_sample_metax.h
src/infiniop/ops/random_sample/metax/random_sample_metax.h
+1
-1
src/infiniop/ops/random_sample/metax/random_sample_metax.maca
...infiniop/ops/random_sample/metax/random_sample_metax.maca
+2
-2
src/infiniop/ops/random_sample/operator.cc
src/infiniop/ops/random_sample/operator.cc
+5
-5
src/infiniop/ops/rearrange/metax/rearrange_kernel.h
src/infiniop/ops/rearrange/metax/rearrange_kernel.h
+0
-0
src/infiniop/ops/rearrange/metax/rearrange_metax.h
src/infiniop/ops/rearrange/metax/rearrange_metax.h
+1
-1
src/infiniop/ops/rearrange/metax/rearrange_metax.maca
src/infiniop/ops/rearrange/metax/rearrange_metax.maca
+3
-3
src/infiniop/ops/rearrange/operator.cc
src/infiniop/ops/rearrange/operator.cc
+4
-4
src/infiniop/ops/relu/cpu/relu_cpu.h
src/infiniop/ops/relu/cpu/relu_cpu.h
+1
-1
src/infiniop/ops/rms_norm/cuda/kernel.cuh
src/infiniop/ops/rms_norm/cuda/kernel.cuh
+2
-5
src/infiniop/ops/rms_norm/metax/rms_norm_metax.cuh
src/infiniop/ops/rms_norm/metax/rms_norm_metax.cuh
+0
-0
src/infiniop/ops/rms_norm/metax/rms_norm_metax.maca
src/infiniop/ops/rms_norm/metax/rms_norm_metax.maca
+30
-12
No files found.
src/infiniop/ops/gemm/m
aca
/gemm_m
aca
.cc
→
src/infiniop/ops/gemm/m
etax
/gemm_m
etax
.cc
View file @
e4605f7c
#include "gemm_m
aca
.h"
#include "gemm_m
etax
.h"
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
namespace
op
::
gemm
::
m
aca
{
namespace
op
::
gemm
::
m
etax
{
struct
Descriptor
::
Opaque
{
std
::
shared_ptr
<
device
::
maca
::
Handle
::
Internal
>
internal
;
...
...
@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
auto
handle
=
reinterpret_cast
<
device
::
maca
::
Handle
*>
(
handle_
);
auto
dtype
=
c_desc
->
dtype
();
if
(
dtype
!=
INFINI_DTYPE_F16
&&
dtype
!=
INFINI_DTYPE_F32
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
CHECK_DTYPE
(
dtype
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_BF16
);
auto
result
=
MatmulInfo
::
create
(
c_desc
,
a_desc
,
b_desc
,
MatrixLayout
::
COL_MAJOR
);
CHECK_RESULT
(
result
);
...
...
@@ -53,7 +51,10 @@ infiniStatus_t Descriptor::calculate(
a_type
=
b_type
=
c_type
=
HPCC_R_16F
;
compute_type
=
HCBLAS_COMPUTE_32F
;
break
;
case
INFINI_DTYPE_BF16
:
a_type
=
b_type
=
c_type
=
HPCC_R_16BF
;
compute_type
=
HCBLAS_COMPUTE_32F
;
break
;
case
INFINI_DTYPE_F32
:
a_type
=
b_type
=
c_type
=
HPCC_R_32F
;
compute_type
=
HCBLAS_COMPUTE_32F_FAST_TF32
;
...
...
@@ -103,4 +104,4 @@ infiniStatus_t Descriptor::calculate(
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace op::gemm::m
aca
}
// namespace op::gemm::m
etax
src/infiniop/ops/gemm/m
aca
/gemm_m
aca
.h
→
src/infiniop/ops/gemm/m
etax
/gemm_m
etax
.h
View file @
e4605f7c
...
...
@@ -3,6 +3,6 @@
#include "../gemm.h"
DESCRIPTOR
(
m
aca
)
DESCRIPTOR
(
m
etax
)
#endif // __GEMM_MACA_H__
src/infiniop/ops/gemm/operator.cc
View file @
e4605f7c
...
...
@@ -15,7 +15,7 @@
#include "ascend/gemm_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "m
aca
/gemm_m
aca
.h"
#include "m
etax
/gemm_m
etax
.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/gemm_musa.h"
...
...
@@ -55,7 +55,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CREATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
musa
);
...
...
@@ -97,7 +97,7 @@ infiniopGetGemmWorkspaceSize(
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
m
aca
);
GET
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
musa
);
...
...
@@ -146,7 +146,7 @@ __C infiniStatus_t infiniopGemm(
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CALCULATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
musa
);
...
...
@@ -185,7 +185,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
m
aca
);
DELETE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
musa
);
...
...
src/infiniop/ops/mul/cpu/mul_cpu.h
View file @
e4605f7c
...
...
@@ -3,7 +3,7 @@
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR
(
mul
,
cpu
)
ELEMENTWISE_DESCRIPTOR
(
mul
,
cpu
,
cpu
)
namespace
op
::
mul
::
cpu
{
typedef
struct
MulOp
{
...
...
src/infiniop/ops/mul/cuda/
mul_cuda_int
ern
a
l.cuh
→
src/infiniop/ops/mul/cuda/
k
ern
e
l.cuh
View file @
e4605f7c
File moved
src/infiniop/ops/mul/
cud
a/mul_
cud
a.cu
→
src/infiniop/ops/mul/
nvidi
a/mul_
nvidi
a.cu
View file @
e4605f7c
#include "
mul_cuda
.cuh"
#include "mul_
cuda_internal
.cuh"
#include "
../cuda/kernel
.cuh"
#include "mul_
nvidia
.cuh"
namespace
op
::
mul
::
cud
a
{
namespace
op
::
mul
::
nvidi
a
{
Descriptor
::~
Descriptor
()
=
default
;
...
...
@@ -43,17 +43,17 @@ infiniStatus_t Descriptor::calculate(
switch
(
_dtype
)
{
case
INFINI_DTYPE_F16
:
return
_device_info
->
calculate
<
256
,
MulOp
,
half
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
return
_device_info
->
calculate
<
256
,
cuda
::
MulOp
,
half
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
case
INFINI_DTYPE_F32
:
return
_device_info
->
calculate
<
256
,
MulOp
,
float
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
return
_device_info
->
calculate
<
256
,
cuda
::
MulOp
,
float
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
case
INFINI_DTYPE_F64
:
return
_device_info
->
calculate
<
256
,
MulOp
,
double
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
return
_device_info
->
calculate
<
256
,
cuda
::
MulOp
,
double
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
case
INFINI_DTYPE_BF16
:
return
_device_info
->
calculate
<
256
,
MulOp
,
__nv_bfloat16
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
return
_device_info
->
calculate
<
256
,
cuda
::
MulOp
,
__nv_bfloat16
>
(
_info
,
workspace
,
output
,
inputs
,
stream
);
default:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace op::mul::
cud
a
}
// namespace op::mul::
nvidi
a
src/infiniop/ops/mul/
cud
a/mul_
cud
a.cuh
→
src/infiniop/ops/mul/
nvidi
a/mul_
nvidi
a.cuh
View file @
e4605f7c
...
...
@@ -3,6 +3,6 @@
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR
(
mul
,
cuda
)
ELEMENTWISE_DESCRIPTOR
(
mul
,
nvidia
,
cuda
)
#endif // __MUL_CUDA_API_H__
src/infiniop/ops/mul/operator.cc
View file @
e4605f7c
...
...
@@ -7,7 +7,7 @@
#endif
#ifdef ENABLE_NVIDIA_API
#include "
cud
a/mul_
cud
a.cuh"
#include "
nvidi
a/mul_
nvidi
a.cuh"
#endif
__C
infiniStatus_t
infiniopCreateMulDescriptor
(
...
...
@@ -32,7 +32,7 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
CREATE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
cud
a
);
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidi
a
);
#endif
default:
...
...
@@ -47,14 +47,14 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::mul::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
;
return INFINI_STATUS_SUCCESS
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU_API
GET
(
INFINI_DEVICE_CPU
,
cpu
)
GET
(
INFINI_DEVICE_CPU
,
cpu
)
;
#endif
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
cud
a
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidi
a
)
;
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopMul(
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cud
a
);
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidi
a
);
#endif
default:
...
...
@@ -108,7 +108,7 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
DELETE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
cud
a
);
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidi
a
);
#endif
default:
...
...
src/infiniop/ops/random_sample/m
aca
/random_sample_kernel.h
→
src/infiniop/ops/random_sample/m
etax
/random_sample_kernel.h
View file @
e4605f7c
...
...
@@ -4,7 +4,7 @@
#include <hccub/device/device_reduce.cuh>
#include <hccub/device/device_scan.cuh>
namespace
op
::
random_sample
::
m
aca
{
namespace
op
::
random_sample
::
m
etax
{
// ↓↓↓ 重新封装 cub api,减少模板参数,方便调用
...
...
@@ -256,4 +256,4 @@ struct Algo {
}
};
}
// namespace op::random_sample::m
aca
}
// namespace op::random_sample::m
etax
src/infiniop/ops/random_sample/m
aca
/random_sample_m
aca
.h
→
src/infiniop/ops/random_sample/m
etax
/random_sample_m
etax
.h
View file @
e4605f7c
...
...
@@ -3,6 +3,6 @@
#include "../random_sample.h"
DESCRIPTOR
(
m
aca
)
DESCRIPTOR
(
m
etax
)
#endif // __RANDOM_SAMPLE_MACA_H__
src/infiniop/ops/random_sample/m
aca
/random_sample_m
aca
.maca
→
src/infiniop/ops/random_sample/m
etax
/random_sample_m
etax
.maca
View file @
e4605f7c
...
...
@@ -2,9 +2,9 @@
#include "../../../devices/maca/maca_handle.h"
#include "../info.h"
#include "random_sample_kernel.h"
#include "random_sample_m
aca
.h"
#include "random_sample_m
etax
.h"
namespace op::random_sample::m
aca
{
namespace op::random_sample::m
etax
{
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
...
...
src/infiniop/ops/random_sample/operator.cc
View file @
e4605f7c
...
...
@@ -9,7 +9,7 @@
#include "cuda/random_sample_cuda.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "m
aca
/random_sample_m
aca
.h"
#include "m
etax
/random_sample_m
etax
.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/random_sample_aclnn.h"
...
...
@@ -39,7 +39,7 @@ infiniopCreateRandomSampleDescriptor(
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CREATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
...
...
@@ -72,7 +72,7 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
m
aca
);
GET
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
...
...
@@ -115,7 +115,7 @@ __C infiniStatus_t infiniopRandomSample(
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CALCULATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
...
...
@@ -145,7 +145,7 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
DELETE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
m
aca
);
DELETE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
#ifdef ENABLE_ASCEND_API
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
...
...
src/infiniop/ops/rearrange/m
aca
/rearrange_kernel.h
→
src/infiniop/ops/rearrange/m
etax
/rearrange_kernel.h
View file @
e4605f7c
File moved
src/infiniop/ops/rearrange/m
aca
/rearrange_m
aca
.h
→
src/infiniop/ops/rearrange/m
etax
/rearrange_m
etax
.h
View file @
e4605f7c
...
...
@@ -3,6 +3,6 @@
#include "../rearrange.h"
DESCRIPTOR
(
m
aca
)
DESCRIPTOR
(
m
etax
)
#endif // __REARRANGE_MACA_H__
src/infiniop/ops/rearrange/m
aca
/rearrange_m
aca
.maca
→
src/infiniop/ops/rearrange/m
etax
/rearrange_m
etax
.maca
View file @
e4605f7c
#include "../../../tensor.h"
#include "rearrange_kernel.h"
#include "rearrange_m
aca
.h"
#include "rearrange_m
etax
.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <stdint.h>
#include <vector>
namespace op::rearrange::m
aca
{
namespace op::rearrange::m
etax
{
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
...
...
@@ -480,4 +480,4 @@ infiniStatus_t Descriptor::calculate(
return status;
}
} // namespace op::rearrange::m
aca
} // namespace op::rearrange::m
etax
src/infiniop/ops/rearrange/operator.cc
View file @
e4605f7c
...
...
@@ -13,7 +13,7 @@
#include "cuda/rearrange_cuda.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "m
aca
/rearrange_m
aca
.h"
#include "m
etax
/rearrange_m
etax
.h"
#endif
__C
infiniStatus_t
infiniopCreateRearrangeDescriptor
(
...
...
@@ -43,7 +43,7 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor(
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CREATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -76,7 +76,7 @@ __C infiniStatus_t infiniopRearrange(
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
m
aca
);
CALCULATE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
default:
...
...
@@ -107,7 +107,7 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor(
DELETE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
m
aca
);
DELETE
(
INFINI_DEVICE_METAX
,
m
etax
);
#endif
default:
...
...
src/infiniop/ops/relu/cpu/relu_cpu.h
View file @
e4605f7c
...
...
@@ -5,7 +5,7 @@
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR
(
relu
,
cpu
)
ELEMENTWISE_DESCRIPTOR
(
relu
,
cpu
,
cpu
)
namespace
op
::
relu
::
cpu
{
typedef
struct
ReluOp
{
...
...
src/infiniop/ops/rms_norm/cuda/
rms_norm_
kernel.cuh
→
src/infiniop/ops/rms_norm/cuda/kernel.cuh
View file @
e4605f7c
#ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tdata
,
typename
Tweight
,
typename
Tcompute
>
INFINIOP_CUDA_KERNEL
rmsnormBlock
(
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tcompute
,
typename
Tdata
,
typename
Tweight
>
__device__
void
rmsnormBlock
(
Tdata
*
__restrict__
y
,
ptrdiff_t
stride_y
,
const
Tdata
*
__restrict__
x
,
...
...
src/infiniop/ops/rms_norm/m
aca
/rms_norm_m
aca
.cuh
→
src/infiniop/ops/rms_norm/m
etax
/rms_norm_m
etax
.cuh
View file @
e4605f7c
File moved
src/infiniop/ops/rms_norm/m
aca
/rms_norm_m
aca
.maca
→
src/infiniop/ops/rms_norm/m
etax
/rms_norm_m
etax
.maca
View file @
e4605f7c
#include "../../../devices/maca/common_maca.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_maca.cuh"
#include "rms_norm_metax.cuh"
#include "../../../devices/maca/maca_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MACA_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::maca {
...
...
@@ -46,14 +64,14 @@ infiniStatus_t launchKernel(
float epsilon,
hcStream_t maca_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnorm
Block
<BLOCK_SIZE, Tdata, Tweight
, Tcompute
><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)
\
rmsnorm
Kernel
<BLOCK_SIZE,
Tcompute,
Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \
reinterpret_cast<Tdata *>(y),
\
stride_y,
\
reinterpret_cast<const Tdata *>(x),
\
stride_x,
\
reinterpret_cast<const Tweight *>(w),
\
dim,
\
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
...
...
@@ -91,8 +109,8 @@ infiniStatus_t Descriptor::calculate(
auto maca_stream = reinterpret_cast<hcStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() ==
CUD
A_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<
CUD
A_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
if (_opaque->internal->maxThreadsPerBlock() ==
MAC
A_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<
MAC
A_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment