jerrrrry / infinicore · Commits

Commit 8d09630a (unverified)
Authored Feb 11, 2026 by gongchensu; committed by GitHub, Feb 11, 2026

Merge branch 'demo131' into Issue/862

Parents: ab52dead, 012df56c
Showing 20 changed files with 972 additions and 37 deletions (+972 −37).
Changed files:

src/infiniop/devices/moore/moore_kernel_common.h (+3 −0)
src/infiniop/devices/nvidia/nvidia_common.cu (+16 −0)
src/infiniop/devices/nvidia/nvidia_handle.h (+11 −0)
src/infiniop/devices/nvidia/nvidia_kernel_common.cuh (+2 −1)
src/infiniop/elementwise/bang/elementwise_bang.h (+2 −2)
src/infiniop/elementwise/bang/elementwise_bang_kernel.h (+2 −2)
src/infiniop/ninetoothed/build.py (+48 −29)
src/infiniop/ninetoothed/utils.h (+75 −0)
src/infiniop/ops/add/bang/add_bang.mlu (+5 −1)
src/infiniop/ops/add/bang/add_bang_internal.mlu (+3 −1)
src/infiniop/ops/add/operator.cc (+13 −1)
src/infiniop/ops/add_rms_norm/add_rms_norm.h (+53 −0)
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc (+147 −0)
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h (+7 −0)
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh (+63 −0)
src/infiniop/ops/add_rms_norm/info.h (+132 −0)
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh (+8 −0)
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca (+191 −0)
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h (+8 −0)
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu (+183 −0)
src/infiniop/devices/moore/moore_kernel_common.h

@@ -6,6 +6,7 @@
 // Posible maximum number of threads per block for MUSA architectures
 // Used for picking correct kernel launch configuration
 #define MOORE_BLOCK_SIZE_4096 4096
+#define MOORE_BLOCK_SIZE_2048 2048
 #define MOORE_BLOCK_SIZE_1024 1024
 #define MOORE_BLOCK_SIZE_512 512

@@ -16,6 +17,8 @@ using cuda_bfloat16 = mt_bfloat16;
 using cuda_bfloat162 = mt_bfloat162;
+using cuda_fp8_e4m3 = __mt_fp8_e4m3;
+using __nv_bfloat16 = __mt_bfloat16;

 namespace device::moore {

 // get the memory offset of the given element in a tensor given its flat index
src/infiniop/devices/nvidia/nvidia_common.cu

@@ -23,6 +23,10 @@ Handle::Internal::Internal(int device_id) {
     _grid_size[0] = prop.maxGridSize[0];
     _grid_size[1] = prop.maxGridSize[1];
     _grid_size[2] = prop.maxGridSize[2];
+    this->useCublas(nullptr, [](cublasHandle_t handle) { return INFINI_STATUS_SUCCESS; });
+#ifdef ENABLE_CUDNN_API
+    this->useCudnn(nullptr, [](cudnnHandle_t handle) { return INFINI_STATUS_SUCCESS; });
+#endif
 }

 infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {

@@ -106,6 +110,18 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
 } // namespace iluvatar

+namespace ali {
+Handle::Handle(int device_id) : nvidia::Handle(INFINI_DEVICE_ALI, device_id) {}
+
+infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
+    *handle_ptr = new Handle(device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace ali
+
 namespace qy {
 Handle::Handle(int device_id)
src/infiniop/devices/nvidia/nvidia_handle.h

@@ -35,6 +35,17 @@ public:
 } // namespace iluvatar

+namespace ali {
+struct Handle : public nvidia::Handle {
+    Handle(int device_id);
+
+public:
+    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
+};
+} // namespace ali
+
 namespace qy {
 struct Handle : public nvidia::Handle {
src/infiniop/devices/nvidia/nvidia_kernel_common.cuh

@@ -16,6 +16,7 @@
 // Posible maximum number of threads per block for CUDA architectures
 // Used for picking correct kernel launch configuration
 #define CUDA_BLOCK_SIZE_4096 4096
+#define CUDA_BLOCK_SIZE_2048 2048
 #define CUDA_BLOCK_SIZE_1024 1024
 #define CUDA_BLOCK_SIZE_512 512

@@ -54,7 +55,7 @@ exp_(const float val) {
     return expf(val);
 }

-#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API)
+#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API) && !defined(ENABLE_ALI_API)
 __forceinline__ __device__ long double exp_(const long double val) {
     return expl(val);
src/infiniop/elementwise/bang/elementwise_bang.h

@@ -127,8 +127,8 @@ private:
     const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;

     // Copy input pointer array and metadata to device
-    CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, CNRT_MEM_TRANS_DIR_HOST2DEV));
-    CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), CNRT_MEM_TRANS_DIR_HOST2DEV));
+    CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, cnrtMemcpyHostToDev));
+    CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), cnrtMemcpyHostToDev));

     // Setup pointers to device memory regions
     d_inputs_arr = reinterpret_cast<const void **>(workspace);
src/infiniop/elementwise/bang/elementwise_bang_kernel.h

@@ -248,10 +248,10 @@ void launchElementwiseKernelWrapper(
     dim.z = 1;

     // Choose kernel type based on problem characteristics
-    cnrtFunctionType_t func_type = CNRT_FUNC_TYPE_BLOCK;
+    cnrtFunctionType_t func_type = cnrtFuncTypeBlock;
     if (output_size > 1024 * 1024 && output_contiguous) {
         // For large contiguous operations, use UNION type
-        func_type = CNRT_FUNC_TYPE_UNION1;
+        func_type = cnrtFuncTypeUnion1;
     }

     // Launch the kernel with optimal configuration
src/infiniop/ninetoothed/build.py

+import concurrent.futures
+import functools
 import inspect
 import itertools

@@ -16,40 +17,28 @@ BUILD_DIRECTORY_PATH = (
 def build(premake, constexpr_param_grid, caller, op_name, output_dir):
     headers = []
     all_param_names = []
+    combinations = []
     launches = []

-    for combination in _generate_param_value_combinations(constexpr_param_grid):
-        arrangement, application, tensors = premake(**combination)
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = []

-        for param_name, param_value in combination.items():
-            if isinstance(param_value, str):
-                combination[param_name] = (
-                    f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
-                )
+        for combination in tuple(_generate_param_value_combinations(constexpr_param_grid)):
+            future = executor.submit(_make, premake, combination, caller, op_name, output_dir)
+            futures.append(future)

-        combination = {f"{name}_": value for name, value in combination.items()}
-
-        kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
-
-        ninetoothed.make(
-            arrangement,
-            application,
-            tensors,
-            caller=caller,
-            kernel_name=kernel_name,
-            output_dir=output_dir,
-        )
-
-        header = output_dir / f"{kernel_name}.h"
-
-        param_names = ("stream",) + tuple(
-            inspect.signature(application).parameters.keys()
-        )
-
-        launch = f"""    if ({_generate_condition(combination)})
-        return launch_{kernel_name}({", ".join(param_names)});"""
-
-        headers.append(header)
-        all_param_names.append(param_names)
-        launches.append(launch)
+        for future in concurrent.futures.as_completed(futures):
+            header, param_names, combination, launch = future.result()
+            headers.append(header)
+            all_param_names.append(param_names)
+            combinations.append(combination)
+            launches.append(launch)

     includes = "\n".join(f'#include "{header}"' for header in headers)

@@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
         "NineToothedStream",
     ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]

-    for param_name in combination:
+    for param_name in functools.reduce(lambda x, y: x | y, combinations, {}):
         param_names.append(param_name)
         param_types.append("int")

@@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
     (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)


+def _make(premake, combination, caller, op_name, output_dir):
+    arrangement, application, tensors = premake(**combination)
+
+    for param_name, param_value in combination.items():
+        if isinstance(param_value, str):
+            combination[param_name] = (
+                f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
+            )
+
+    combination = {f"{name}_": value for name, value in combination.items()}
+
+    kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
+
+    ninetoothed.make(
+        arrangement,
+        application,
+        tensors,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+    )
+
+    header = output_dir / f"{kernel_name}.h"
+
+    param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys())
+
+    launch = f"""    if ({_generate_condition(combination)})
+        return launch_{kernel_name}({", ".join(param_names)});"""
+
+    return header, param_names, combination, launch
+
+
 def _generate_condition(combination):
     return " && ".join(f"{param} == {value}" for param, value in combination.items())
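The refactor above moves all per-combination work (dtype-macro rewriting, kernel naming, the ninetoothed.make call, and launch-snippet generation) into the new top-level _make helper so that each parameter combination can be compiled in its own process. A minimal, self-contained sketch of the same submit/as_completed fan-out pattern follows; make_one and build_all are illustrative stand-ins, not names from build.py:

import concurrent.futures
import itertools


def make_one(combination):
    # Hypothetical stand-in for _make: expensive, independent work per combination.
    name = "_".join(f"{key}{value}" for key, value in sorted(combination.items()))
    return name, combination


def build_all(param_grid):
    combos = [dict(zip(param_grid, values))
              for values in itertools.product(*param_grid.values())]
    results = []
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [executor.submit(make_one, combo) for combo in combos]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())  # re-raises worker exceptions here
    return results


if __name__ == "__main__":  # guard required by ProcessPoolExecutor on spawn platforms
    print(build_all({"dtype": ["fp16", "fp32"], "block_size": [128, 256]}))

Because as_completed yields results out of order, everything the caller needs per combination travels in the returned tuple (header, param_names, combination, launch) rather than in loop state, which is also why build() now collects combinations and later merges them with functools.reduce.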
src/infiniop/ninetoothed/utils.h (new file, 0 → 100644)

#ifndef __NINETOOTHED_UTILS__
#define __NINETOOTHED_UTILS__

#include <initializer_list>
#include <limits>
#include <type_traits>
#include <vector>

namespace ninetoothed {

template <typename T = float>
class Tensor {
public:
    using Data = decltype(NineToothedTensor::data);
    using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;
    using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;

    template <typename Shape, typename Strides>
    Tensor(const void *data, Shape shape, Strides strides)
        : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}

    Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides)
        : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}

    Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim)
        : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}

    Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}

    operator NineToothedTensor() {
        return {const_cast<Data>(data_), shape_.data(), strides_.data()};
    }

    template <typename Shape>
    Tensor expand(const Shape &sizes) const {
        auto new_ndim{sizes.size()};
        decltype(shape_) shape(new_ndim, 1);
        decltype(strides_) strides(new_ndim, 0);

        auto num_new_dims{new_ndim - ndim_};

        for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
            shape[dim + num_new_dims] = shape_[dim];
            strides[dim + num_new_dims] = strides_[dim];
        }

        for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
            if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
                continue;
            }

            shape[dim] = sizes[dim];
            strides[dim] = 0;
        }

        return {data_, shape, strides};
    }

    Tensor expand_as(const Tensor &other) const {
        return expand(other.shape_);
    }

private:
    const void *data_{nullptr};
    std::vector<Size> shape_;
    std::vector<Stride> strides_;
    Size ndim_{0};
    T value_{0};
};

} // namespace ninetoothed

#endif
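Tensor::expand above follows the usual broadcast rule: missing leading dimensions are prepended with size 1 and stride 0, and a size-1 dimension being enlarged takes shape sizes[dim] with stride 0, so every index along it aliases the same element; a requested size equal to std::numeric_limits<...>::max() means "leave this dimension unchanged". A small Python model of that rule (KEEP stands in for the sentinel; the data is illustrative):

# Python model of Tensor::expand from utils.h: prepend leading dims,
# then broadcast size-1 dims by giving them stride 0.
KEEP = object()  # stands in for std::numeric_limits<...>::max()


def expand(shape, strides, sizes):
    num_new = len(sizes) - len(shape)
    out_shape = [1] * num_new + list(shape)
    out_strides = [0] * num_new + list(strides)
    for d, size in enumerate(sizes):
        if size is KEEP or out_shape[d] != 1:
            continue  # either "keep as is" or not broadcastable from size 1
        out_shape[d] = size
        out_strides[d] = 0  # all indices along d alias the same element
    return out_shape, out_strides


# Broadcast a weight vector of shape (8,) across (4, 2, 8):
print(expand([8], [1], [4, 2, KEEP]))  # -> ([4, 2, 8], [0, 0, 1])

Setting the stride to 0 is what makes expand a view rather than a copy: the flat offset sum(index[d] * stride[d]) simply ignores broadcast dimensions.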
src/infiniop/ops/add/bang/add_bang.mlu

@@ -31,7 +31,7 @@ infiniStatus_t Descriptor::create(
     const auto &a_shape = a_desc->shape();
     const auto &b_shape = b_desc->shape();

-    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_I32, INFINI_DTYPE_I64);

     CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

@@ -59,6 +59,10 @@ infiniStatus_t Descriptor::calculate(
         return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
     case INFINI_DTYPE_F32:
         return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<AddOp, int32_t>(_info, workspace, output, inputs, queue);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<AddOp, int64_t>(_info, workspace, output, inputs, queue);
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
src/infiniop/ops/add/bang/add_bang_internal.mlu

@@ -8,7 +8,7 @@ public:
     static constexpr size_t num_inputs = 2;

     template <typename T>
     __mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
-        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
+        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float> || std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>) {
             __bang_add(out, a, b, num_elements);
         } else {
             out = a + b;

@@ -21,5 +21,7 @@ LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
+LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int32_t)
+LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int64_t)

 #endif // __ADD_BANG_INTERNAL_H__
src/infiniop/ops/add/operator.cc

@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/add_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
 #include "nvidia/add_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API

@@ -48,6 +48,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
 #ifdef ENABLE_ILUVATAR_API
     CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+    CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
     CREATE(INFINI_DEVICE_QY, nvidia);
 #endif

@@ -91,6 +94,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
 #ifdef ENABLE_ILUVATAR_API
     GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+    GET(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
     GET(INFINI_DEVICE_QY, nvidia);
 #endif

@@ -142,6 +148,9 @@ __C infiniStatus_t infiniopAdd(
 #ifdef ENABLE_ILUVATAR_API
     CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+    CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
     CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif

@@ -187,6 +196,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
 #ifdef ENABLE_ILUVATAR_API
     DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+    DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
     DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
src/infiniop/ops/add_rms_norm/add_rms_norm.h (new file, 0 → 100644)

#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H

#include "../../operator.h"
#include "info.h"

#define DESCRIPTOR(NAMESPACE)                                     \
                                                                  \
    namespace op::add_rms_norm::NAMESPACE {                       \
    class Descriptor final : public InfiniopDescriptor {          \
        struct Opaque;                                            \
        Opaque *_opaque;                                          \
        AddRMSNormInfo _info;                                     \
        size_t _workspace_size;                                   \
                                                                  \
        Descriptor(                                               \
            Opaque *opaque,                                       \
            AddRMSNormInfo info,                                  \
            size_t workspace_size,                                \
            infiniDevice_t device_type,                           \
            int device_id)                                        \
            : InfiniopDescriptor{device_type, device_id},         \
              _opaque(opaque),                                    \
              _info(info),                                        \
              _workspace_size(workspace_size) {}                  \
                                                                  \
    public:                                                       \
        ~Descriptor();                                            \
                                                                  \
        size_t workspaceSize() const { return _workspace_size; }  \
                                                                  \
        static infiniStatus_t create(                             \
            infiniopHandle_t handle,                              \
            Descriptor **desc_ptr,                                \
            infiniopTensorDescriptor_t y_desc,                    \
            infiniopTensorDescriptor_t residual_out_desc,         \
            infiniopTensorDescriptor_t a_desc,                    \
            infiniopTensorDescriptor_t b_desc,                    \
            infiniopTensorDescriptor_t weight_desc,               \
            float epsilon);                                       \
                                                                  \
        infiniStatus_t calculate(                                 \
            void *workspace, size_t workspace_size,               \
            void *y,                                              \
            void *residual_out,                                   \
            const void *a,                                        \
            const void *b,                                        \
            const void *weight,                                   \
            void *stream) const;                                  \
    };                                                            \
    }

#endif // ADD_RMS_NORM_H
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc (new file, 0 → 100644)

#include "add_rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"

namespace op::add_rms_norm::cpu {

Descriptor::~Descriptor() {}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t residual_out_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon) {
    auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) {
    const size_t batch_size = info->shape[0];
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
        T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

        // Compute add(a, b) once and store it
        T sum_squared = (T)0;
        for (size_t k = 0; k < dim; k++) {
            T sum_val = a_ptr[k] + b_ptr[k];
            residual_out_ptr[k] = sum_val; // Store add result
            sum_squared += sum_val * sum_val;
        }

        // Compute RMS: 1 / (sqrt(mean(sum^2) + eps))
        // Note: mean = sum_squared / dim
        T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));

        // Apply normalization: y = (a + b) * w * rms
        // Reuse stored values from residual_out
        for (size_t k = 0; k < dim; k++) {
            y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
        }
    }

    return INFINI_STATUS_SUCCESS;
}

template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) {
    static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
                  "T must be fp16_t or bf16_t");

    const size_t batch_size = info->shape[0];
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
        T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

        // Compute sum of squares for RMS normalization and store add result
        float sum_squared = 0.0f;
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
            residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
            sum_squared += sum_val * sum_val;
        }

        // Compute RMS: 1 / (sqrt(sum/dim + eps))
        float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);

        // Apply normalization: y = (a + b) * w * rms
        // Reuse stored values from residual_out
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(residual_out_ptr[k]);
            float val;
            if constexpr (std::is_same<Tw, float>::value) {
                val = sum_val * w[k] * rms;
            } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
                val = sum_val * utils::cast<float>(w[k]) * rms;
            } else {
                std::abort();
            }
            y_ptr[k] = utils::cast<T>(val);
        }
    }

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    void *residual_out,
    const void *a,
    const void *b,
    const void *weight,
    void *stream) const {
    if (_info.atype == INFINI_DTYPE_F16) {
        if (_info.wtype == INFINI_DTYPE_F16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
        } else if (_info.wtype == INFINI_DTYPE_BF16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_BF16) {
        if (_info.wtype == INFINI_DTYPE_BF16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
        } else if (_info.wtype == INFINI_DTYPE_F16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_F32) {
        CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight));
    } else if (_info.atype == INFINI_DTYPE_F64) {
        CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight));
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::cpu
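Both CPU paths implement the same fused row computation: residual_out = a + b, followed by y = residual_out * w / sqrt(mean(residual_out^2) + eps), with the half-precision variant accumulating in float. A pure-Python reference for a single row (function name and inputs are illustrative), handy for spot-checking any of the backends in this commit on small inputs:

# Pure-Python reference of the fused add + RMSNorm row computation
# implemented by add_rmsnorm()/add_rmsnormHalfPrecision() above.
import math


def add_rms_norm_row(a, b, w, epsilon):
    residual_out = [x + y for x, y in zip(a, b)]        # residual_out = a + b
    mean_sq = sum(v * v for v in residual_out) / len(residual_out)
    rms = 1.0 / math.sqrt(mean_sq + epsilon)            # 1 / sqrt(mean(sum^2) + eps)
    y = [v * wk * rms for v, wk in zip(residual_out, w)]
    return y, residual_out


y, res = add_rms_norm_row([1.0, 2.0], [3.0, 4.0], [0.5, 0.5], 1e-6)
print(res)  # [4.0, 6.0]
print(y)    # each element is residual * 0.5 / sqrt((16 + 36) / 2 + 1e-6)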
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h (new file, 0 → 100644)

#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__

#include "../add_rms_norm.h"

DESCRIPTOR(cpu)

#endif
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh (new file, 0 → 100644)

#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__

#include <cub/block/block_reduce.cuh>

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    // Each block takes care of one head in one batch
    // Each thread deals with every block_size element in the row
    size_t batch_idx = blockIdx.x / nhead;
    size_t head_idx = blockIdx.x % nhead;

    auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    auto w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;

    // Compute add(a, b) and sum of squares in one pass
    Tcompute sum_squared = 0;
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
        residual_out_ptr[i] = Tdata(sum_val); // Store add result
        sum_squared += sum_val * sum_val;
    }

    // Block-reduce sum of squares
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);

    // Thread_0 computes RMS = 1/sqrt(ss/dim + epsilon) and stores it in shared memory
    __shared__ Tcompute rms;
    if (threadIdx.x == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
    }
    __syncthreads();

    // Apply normalization: y = (a + b) * w * rms
    // Reuse stored values from residual_out
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
        y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
    }
}

#endif
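In add_rmsnormBlock, one block owns one (batch, head) row: each thread strides through the row in steps of BLOCK_SIZE, storing a + b into residual_out while accumulating a partial sum of squares; cub::BlockReduce folds the partials into one total, thread 0 publishes rms through shared memory, and a second strided pass rescales the stored sums. A sequential Python simulation of that schedule (BLOCK_SIZE and the data are illustrative):

# Simulates the two-pass, block-strided schedule of add_rmsnormBlock:
# pass 1 stores a+b and per-thread partial sums, a reduce yields rms,
# pass 2 rescales the stored sums.
import math

BLOCK_SIZE = 4                       # illustrative; the real kernels use 512-4096
a = [float(i) for i in range(10)]
b = [1.0] * 10
w = [0.5] * 10
epsilon = 1e-6
dim = len(a)

residual_out = [0.0] * dim
partial = [0.0] * BLOCK_SIZE         # one accumulator per simulated thread
for tid in range(BLOCK_SIZE):
    for i in range(tid, dim, BLOCK_SIZE):   # the i += BLOCK_SIZE strided loop
        s = a[i] + b[i]
        residual_out[i] = s
        partial[tid] += s * s

sum_squared = sum(partial)           # stands in for cub::BlockReduce(...).Sum
rms = 1.0 / math.sqrt(sum_squared / dim + epsilon)

y = [0.0] * dim
for tid in range(BLOCK_SIZE):
    for i in range(tid, dim, BLOCK_SIZE):
        y[i] = residual_out[i] * w[i] * rms

print(y)

Writing the sums into residual_out during the first pass is what lets the second pass avoid re-reading a and b.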
src/infiniop/ops/add_rms_norm/info.h (new file, 0 → 100644)

#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>

namespace op::add_rms_norm {

class AddRMSNormInfo {
    AddRMSNormInfo() = default;

public:
    infiniDtype_t wtype;
    infiniDtype_t atype;
    float epsilon;
    std::vector<size_t> shape;
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> residual_out_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;
    bool has_residual_out;

    size_t ndim() const { return shape.size(); }
    size_t dim() const { return shape[ndim() - 1]; }

    static utils::Result<AddRMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t residual_out_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t weight_desc,
        float epsilon) {
        auto atype = y_desc->dtype();
        auto wtype = weight_desc->dtype();

        // Check that all input tensors have the same dtype
        if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
            // For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
            if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
            // For FP32/FP64, activations and weights must be of the same type
            if (atype != wtype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        const size_t y_ndim = y_desc->ndim();
        const size_t a_ndim = a_desc->ndim();
        const size_t b_ndim = b_desc->ndim();
        const size_t w_ndim = weight_desc->ndim();

        if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch = 1;
        size_t nhead = 1;
        size_t dim = 0;

        if (y_ndim == 2) {
            batch = y_desc->dim(0);
            dim = y_desc->dim(1);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != dim
                || b_desc->dim(0) != batch || b_desc->dim(1) != dim
                || weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else if (y_ndim == 3) {
            batch = y_desc->dim(0);
            nhead = y_desc->dim(1);
            dim = y_desc->dim(2);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim
                || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim
                || weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Check contiguity of the last dimension
        if (y_desc->stride(y_ndim - 1) != 1
            || a_desc->stride(a_ndim - 1) != 1
            || b_desc->stride(b_ndim - 1) != 1
            || weight_desc->stride(w_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        // residual_out_desc is required (always needed for the fused operator)
        if (residual_out_desc == nullptr) {
            return INFINI_STATUS_BAD_PARAM;
        }

        const size_t residual_out_ndim = residual_out_desc->ndim();
        if (residual_out_ndim != y_ndim) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (residual_out_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // Check shape matches
        for (size_t i = 0; i < y_ndim; i++) {
            if (residual_out_desc->dim(i) != y_desc->dim(i)) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        }
        if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        AddRMSNormInfo info;
        info.wtype = wtype;
        info.atype = atype;
        info.epsilon = epsilon;
        info.shape = y_desc->shape();
        info.y_strides = y_desc->strides();
        info.a_strides = a_desc->strides();
        info.b_strides = b_desc->strides();
        info.has_residual_out = true; // Always true now
        info.residual_out_strides = residual_out_desc->strides();

        return utils::Result<AddRMSNormInfo>(info);
    }
};

} // namespace op::add_rms_norm

#endif // __ADD_RMS_NORM_INFO_H__
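AddRMSNormInfo::create accepts only 2-D (batch, dim) or 3-D (batch, nhead, dim) activations with a 1-D weight of length dim, requires y, residual_out, a, and b to share one shape, and insists on unit stride in the last dimension. A compact Python restatement of just those shape/stride rules (dtype checks omitted; check_shapes is an illustrative name, not part of the library):

# Python restatement of the shape/stride validation in AddRMSNormInfo::create.
def check_shapes(y, residual_out, a, b, weight, last_strides):
    """y/residual_out/a/b are shape tuples; weight is a 1-D shape tuple;
    last_strides maps tensor name -> stride of its last dimension."""
    if len(y) not in (2, 3) or len(weight) != 1:
        return "BAD_TENSOR_SHAPE"
    if not (y == a == b == residual_out):
        return "BAD_TENSOR_SHAPE"
    if weight[0] != y[-1]:                      # weight length must equal dim
        return "BAD_TENSOR_SHAPE"
    if any(s != 1 for s in last_strides.values()):
        return "BAD_TENSOR_STRIDES"             # rows must be contiguous
    return "SUCCESS"


print(check_shapes((4, 8, 64), (4, 8, 64), (4, 8, 64), (4, 8, 64), (64,),
                   {"y": 1, "residual_out": 1, "a": 1, "b": 1, "w": 1}))  # SUCCESS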
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh (new file, 0 → 100644)

#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__

#include "../add_rms_norm.h"

DESCRIPTOR(metax)

#endif
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca (new file, 0 → 100644)

#include "../../../devices/metax/metax_common.h"
#include "add_rms_norm_metax.cuh"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Metax platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::metax {
// Internal opaque structure for Metax device handle
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor
Descriptor::~Descriptor() {
delete _opaque;
}
// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
hcStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Extract tensor strides and dimensions
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::metax
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h (new file, 0 → 100644)

#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__

#include "../add_rms_norm.h"

DESCRIPTOR(moore)

#endif
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu (new file, 0 → 100644)

#include "../../../devices/moore/moore_common.h"
#include "add_rms_norm_moore.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Moore platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::moore {
// Internal opaque structure for Moore device handle
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor
Descriptor::~Descriptor() {
delete _opaque;
}
// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Handle different data type combinations
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__mt_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Extract tensor strides and dimensions
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// Launch kernel with appropriate block size based on device capability
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::moore