Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
784139b9
Unverified
Commit
784139b9
authored
Feb 13, 2026
by
thatPepe
Committed by
GitHub
Feb 13, 2026
Browse files
Merge pull request #990 from InfiniTensor/demo131
Demo-131 Cuda graph with optimized paged attention
parents
3c8fb3c0
1d6527cb
Changes
582
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
662 additions
and
77 deletions
+662
-77
src/infiniop/devices/moore/moore_kernel_common.h
src/infiniop/devices/moore/moore_kernel_common.h
+3
-0
src/infiniop/devices/nvidia/nvidia_common.cu
src/infiniop/devices/nvidia/nvidia_common.cu
+12
-0
src/infiniop/devices/nvidia/nvidia_handle.h
src/infiniop/devices/nvidia/nvidia_handle.h
+11
-0
src/infiniop/devices/nvidia/nvidia_kernel_common.cuh
src/infiniop/devices/nvidia/nvidia_kernel_common.cuh
+4
-1
src/infiniop/elementwise/bang/elementwise_bang.h
src/infiniop/elementwise/bang/elementwise_bang.h
+2
-2
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
+2
-2
src/infiniop/ninetoothed/build.py
src/infiniop/ninetoothed/build.py
+48
-29
src/infiniop/ninetoothed/utils.h
src/infiniop/ninetoothed/utils.h
+75
-0
src/infiniop/ops/add/bang/add_bang.mlu
src/infiniop/ops/add/bang/add_bang.mlu
+5
-1
src/infiniop/ops/add/bang/add_bang_internal.mlu
src/infiniop/ops/add/bang/add_bang_internal.mlu
+3
-1
src/infiniop/ops/add/operator.cc
src/infiniop/ops/add/operator.cc
+25
-1
src/infiniop/ops/add_rms_norm/add_rms_norm.h
src/infiniop/ops/add_rms_norm/add_rms_norm.h
+3
-3
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
+15
-15
src/infiniop/ops/add_rms_norm/info.h
src/infiniop/ops/add_rms_norm/info.h
+3
-3
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
+8
-0
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
+191
-0
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h
+8
-0
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
+183
-0
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
+16
-8
src/infiniop/ops/add_rms_norm/operator.cc
src/infiniop/ops/add_rms_norm/operator.cc
+45
-11
No files found.
src/infiniop/devices/moore/moore_kernel_common.h
View file @
784139b9
...
...
@@ -6,6 +6,7 @@
// Possible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_4096 4096
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
...
...
@@ -16,6 +17,8 @@ using cuda_bfloat16 = mt_bfloat16;
using
cuda_bfloat162
=
mt_bfloat162
;
using
cuda_fp8_e4m3
=
__mt_fp8_e4m3
;
using
__nv_bfloat16
=
__mt_bfloat16
;
namespace
device
::
moore
{
// get the memory offset of the given element in a tensor given its flat index
...
...
src/infiniop/devices/nvidia/nvidia_common.cu
View file @
784139b9
...
...
@@ -110,6 +110,18 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
}
// namespace iluvatar
namespace
ali
{
Handle
::
Handle
(
int
device_id
)
:
nvidia
::
Handle
(
INFINI_DEVICE_ALI
,
device_id
)
{}
infiniStatus_t
Handle
::
create
(
InfiniopHandle
**
handle_ptr
,
int
device_id
)
{
*
handle_ptr
=
new
Handle
(
device_id
);
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace ali
namespace
qy
{
Handle
::
Handle
(
int
device_id
)
...
...
src/infiniop/devices/nvidia/nvidia_handle.h
View file @
784139b9
...
...
@@ -35,6 +35,17 @@ public:
}
// namespace iluvatar
namespace
ali
{
struct
Handle
:
public
nvidia
::
Handle
{
Handle
(
int
device_id
);
public:
static
infiniStatus_t
create
(
InfiniopHandle
**
handle_ptr
,
int
device_id
);
};
}
// namespace ali
namespace
qy
{
struct
Handle
:
public
nvidia
::
Handle
{
...
...
src/infiniop/devices/nvidia/nvidia_kernel_common.cuh
View file @
784139b9
...
...
@@ -9,11 +9,14 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef ENABLE_HYGON_API
#include <cuda_fp8.h>
#endif
// Possible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_4096 4096
#define CUDA_BLOCK_SIZE_2048 2048
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
...
...
@@ -52,7 +55,7 @@ exp_(const float val) {
return
expf
(
val
);
}
#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API)
#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API)
&& !defined(ENABLE_ALI_API)
__forceinline__
__device__
long
double
exp_
(
const
long
double
val
)
{
return
expl
(
val
);
...
...
src/infiniop/elementwise/bang/elementwise_bang.h
View file @
784139b9
...
...
@@ -127,8 +127,8 @@ private:
const
int8_t
*
d_meta_start
=
reinterpret_cast
<
int8_t
*>
(
workspace
)
+
input_arr_size
;
// Copy input pointer array and metadata to device
CNRT_CHECK
(
cnrtMemcpy
(
workspace
,
(
void
*
)
h_inputs_arr
,
input_arr_size
,
CNRT_MEM_TRANS_DIR_HOST2DEV
));
CNRT_CHECK
(
cnrtMemcpy
((
void
*
)
d_meta_start
,
(
void
*
)
info_meta_start
,
info
.
getMetaMemSize
(),
CNRT_MEM_TRANS_DIR_HOST2DEV
));
CNRT_CHECK
(
cnrtMemcpy
(
workspace
,
(
void
*
)
h_inputs_arr
,
input_arr_size
,
cnrtMemcpyHostToDev
));
CNRT_CHECK
(
cnrtMemcpy
((
void
*
)
d_meta_start
,
(
void
*
)
info_meta_start
,
info
.
getMetaMemSize
(),
cnrtMemcpyHostToDev
));
// Setup pointers to device memory regions
d_inputs_arr
=
reinterpret_cast
<
const
void
**>
(
workspace
);
...
...
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
View file @
784139b9
...
...
@@ -248,10 +248,10 @@ void launchElementwiseKernelWrapper(
dim
.
z
=
1
;
// Choose kernel type based on problem characteristics
cnrtFunctionType_t
func_type
=
CNRT_FUNC_TYPE_BLOCK
;
cnrtFunctionType_t
func_type
=
cnrtFuncTypeBlock
;
if
(
output_size
>
1024
*
1024
&&
output_contiguous
)
{
// For large contiguous operations, use UNION type
func_type
=
CNRT_FUNC_TYPE_UNION
1
;
func_type
=
cnrtFuncTypeUnion
1
;
}
// Launch the kernel with optimal configuration
...
...
src/infiniop/ninetoothed/build.py
View file @
784139b9
import
concurrent.futures
import
functools
import
inspect
import
itertools
...
...
@@ -16,40 +17,28 @@ BUILD_DIRECTORY_PATH = (
def
build
(
premake
,
constexpr_param_grid
,
caller
,
op_name
,
output_dir
):
headers
=
[]
all_param_names
=
[]
combinations
=
[]
launches
=
[]
for
combination
in
_generate_param_value_combinations
(
constexpr_param_grid
)
:
arrangement
,
application
,
tensors
=
premake
(
**
combination
)
with
concurrent
.
futures
.
ProcessPoolExecutor
()
as
executor
:
futures
=
[]
for
param_name
,
param_value
in
combination
.
items
():
if
isinstance
(
param_value
,
str
):
combination
[
param_name
]
=
(
f
"INFINI_DTYPE_
{
combination
[
param_name
].
replace
(
'fp'
,
'F'
).
upper
()
}
"
)
for
combination
in
tuple
(
_generate_param_value_combinations
(
constexpr_param_grid
)
):
future
=
executor
.
submit
(
_make
,
premake
,
combination
,
caller
,
op_name
,
output_dir
)
combination
=
{
f
"
{
name
}
_"
:
value
for
name
,
value
in
combination
.
items
()}
futures
.
append
(
future
)
kernel_name
=
f
"
{
op_name
}
_
{
_generate_suffix
(
combination
.
values
())
}
"
for
future
in
concurrent
.
futures
.
as_completed
(
futures
):
header
,
param_names
,
combination
,
launch
=
future
.
result
()
ninetoothed
.
make
(
arrangement
,
application
,
tensors
,
caller
=
caller
,
kernel_name
=
kernel_name
,
output_dir
=
output_dir
,
)
header
=
output_dir
/
f
"
{
kernel_name
}
.h"
param_names
=
(
"stream"
,)
+
tuple
(
inspect
.
signature
(
application
).
parameters
.
keys
()
)
launch
=
f
""" if (
{
_generate_condition
(
combination
)
}
)
return launch_
{
kernel_name
}
(
{
", "
.
join
(
param_names
)
}
);"""
headers
.
append
(
header
)
all_param_names
.
append
(
param_names
)
launches
.
append
(
launch
)
headers
.
append
(
header
)
all_param_names
.
append
(
param_names
)
combinations
.
append
(
combination
)
launches
.
append
(
launch
)
includes
=
"
\n
"
.
join
(
f
'#include "
{
header
}
"'
for
header
in
headers
)
...
...
@@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
"NineToothedStream"
,
]
+
[
"NineToothedTensor"
for
_
in
range
(
len
(
param_names
)
-
1
)]
for
param_name
in
combination
:
for
param_name
in
functools
.
reduce
(
lambda
x
,
y
:
x
|
y
,
combination
s
,
{})
:
param_names
.
append
(
param_name
)
param_types
.
append
(
"int"
)
...
...
@@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
(
BUILD_DIRECTORY_PATH
/
header_file_name
).
write_text
(
header_content
)
def
_make
(
premake
,
combination
,
caller
,
op_name
,
output_dir
):
arrangement
,
application
,
tensors
=
premake
(
**
combination
)
for
param_name
,
param_value
in
combination
.
items
():
if
isinstance
(
param_value
,
str
):
combination
[
param_name
]
=
(
f
"INFINI_DTYPE_
{
combination
[
param_name
].
replace
(
'fp'
,
'F'
).
upper
()
}
"
)
combination
=
{
f
"
{
name
}
_"
:
value
for
name
,
value
in
combination
.
items
()}
kernel_name
=
f
"
{
op_name
}
_
{
_generate_suffix
(
combination
.
values
())
}
"
ninetoothed
.
make
(
arrangement
,
application
,
tensors
,
caller
=
caller
,
kernel_name
=
kernel_name
,
output_dir
=
output_dir
,
)
header
=
output_dir
/
f
"
{
kernel_name
}
.h"
param_names
=
(
"stream"
,)
+
tuple
(
inspect
.
signature
(
application
).
parameters
.
keys
())
launch
=
f
""" if (
{
_generate_condition
(
combination
)
}
)
return launch_
{
kernel_name
}
(
{
", "
.
join
(
param_names
)
}
);"""
return
header
,
param_names
,
combination
,
launch
def
_generate_condition
(
combination
):
return
" && "
.
join
(
f
"
{
param
}
==
{
value
}
"
for
param
,
value
in
combination
.
items
())
...
...
src/infiniop/ninetoothed/utils.h
0 → 100644
View file @
784139b9
#ifndef __NINETOOTHED_UTILS__
#define __NINETOOTHED_UTILS__
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <vector>
namespace
ninetoothed
{
template
<
typename
T
=
float
>
class
Tensor
{
public:
using
Data
=
decltype
(
NineToothedTensor
::
data
);
using
Size
=
std
::
remove_pointer_t
<
decltype
(
NineToothedTensor
::
shape
)
>
;
using
Stride
=
std
::
remove_pointer_t
<
decltype
(
NineToothedTensor
::
strides
)
>
;
template
<
typename
Shape
,
typename
Strides
>
Tensor
(
const
void
*
data
,
Shape
shape
,
Strides
strides
)
:
data_
{
data
},
shape_
{
shape
},
strides_
{
strides
},
ndim_
{
shape_
.
size
()}
{}
Tensor
(
const
void
*
data
,
std
::
initializer_list
<
Size
>
shape
,
std
::
initializer_list
<
Stride
>
strides
)
:
Tensor
{
data
,
decltype
(
shape_
){
shape
},
decltype
(
strides_
){
strides
}}
{}
Tensor
(
const
void
*
data
,
const
Size
*
shape
,
const
Stride
*
strides
,
Size
ndim
)
:
data_
{
data
},
shape_
{
shape
,
shape
+
ndim
},
strides_
{
strides
,
strides
+
ndim
},
ndim_
{
shape_
.
size
()}
{}
Tensor
(
const
T
value
)
:
value_
{
value
},
data_
{
&
value_
},
ndim_
{
0
}
{}
operator
NineToothedTensor
()
{
return
{
const_cast
<
Data
>
(
data_
),
shape_
.
data
(),
strides_
.
data
()};
}
template
<
typename
Shape
>
Tensor
expand
(
const
Shape
&
sizes
)
const
{
auto
new_ndim
{
sizes
.
size
()};
decltype
(
shape_
)
shape
(
new_ndim
,
1
);
decltype
(
strides_
)
strides
(
new_ndim
,
0
);
auto
num_new_dims
{
new_ndim
-
ndim_
};
for
(
auto
dim
{
decltype
(
ndim_
){
0
}};
dim
<
ndim_
;
++
dim
)
{
shape
[
dim
+
num_new_dims
]
=
shape_
[
dim
];
strides
[
dim
+
num_new_dims
]
=
strides_
[
dim
];
}
for
(
auto
dim
{
decltype
(
new_ndim
){
0
}};
dim
<
new_ndim
;
++
dim
)
{
if
(
sizes
[
dim
]
==
std
::
numeric_limits
<
std
::
remove_reference_t
<
decltype
(
sizes
[
dim
])
>>::
max
()
||
shape
[
dim
]
!=
1
)
{
continue
;
}
shape
[
dim
]
=
sizes
[
dim
];
strides
[
dim
]
=
0
;
}
return
{
data_
,
shape
,
strides
};
}
Tensor
expand_as
(
const
Tensor
&
other
)
const
{
return
expand
(
other
.
shape_
);
}
private:
const
void
*
data_
{
nullptr
};
std
::
vector
<
Size
>
shape_
;
std
::
vector
<
Stride
>
strides_
;
Size
ndim_
{
0
};
T
value_
{
0
};
};
}
// namespace ninetoothed
#endif
src/infiniop/ops/add/bang/add_bang.mlu
View file @
784139b9
...
...
@@ -31,7 +31,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32
, INFINI_DTYPE_I32, INFINI_DTYPE_I64
);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
...
...
@@ -59,6 +59,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F32:
return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I32:
return _device_info->calculate<AddOp, int32_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I64:
return _device_info->calculate<AddOp, int64_t>(_info, workspace, output, inputs, queue);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
...
...
src/infiniop/ops/add/bang/add_bang_internal.mlu
View file @
784139b9
...
...
@@ -8,7 +8,7 @@ public:
static constexpr size_t num_inputs = 2;
template <typename T>
__mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>
|| std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>
) {
__bang_add(out, a, b, num_elements);
} else {
out = a + b;
...
...
@@ -21,5 +21,7 @@ LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int32_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int64_t)
#endif // __ADD_BANG_INTERNAL_H__
src/infiniop/ops/add/operator.cc
View file @
784139b9
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
|| defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/add_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
...
...
@@ -48,9 +48,15 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CREATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
...
...
@@ -88,9 +94,15 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
GET
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
...
...
@@ -136,9 +148,15 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
...
...
@@ -178,9 +196,15 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
DELETE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
DELETE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
DELETE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
...
...
src/infiniop/ops/add_rms_norm/add_rms_norm.h
View file @
784139b9
...
...
@@ -33,19 +33,19 @@
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
void *residual_out, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
...
...
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
View file @
784139b9
...
...
@@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
,
residual_out_desc
);
float
epsilon
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
residual_out_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
);
CHECK_RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
(),
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
T
>
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
,
T
*
residual_out
)
{
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
T
*
residual_out
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
)
{
const
size_t
batch_size
=
info
->
shape
[
0
];
const
size_t
nhead
=
info
->
ndim
()
>
2
?
info
->
shape
[
1
]
:
1
;
const
size_t
dim
=
info
->
dim
();
...
...
@@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
}
template
<
typename
T
,
typename
Tw
>
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
,
T
*
residual_out
)
{
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
T
*
residual_out
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
)
{
static_assert
(
std
::
is_same
<
T
,
fp16_t
>::
value
||
std
::
is_same
<
T
,
bf16_t
>::
value
,
"T must be fp16_t or bf16_t"
);
...
...
@@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
const
{
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
stream
)
const
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_BF16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
,
(
float
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
float
*
)
residual_out
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
,
(
double
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
double
*
)
residual_out
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
...
...
src/infiniop/ops/add_rms_norm/info.h
View file @
784139b9
...
...
@@ -16,9 +16,9 @@ public:
float
epsilon
;
std
::
vector
<
size_t
>
shape
;
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
residual_out_strides
;
std
::
vector
<
ptrdiff_t
>
a_strides
;
std
::
vector
<
ptrdiff_t
>
b_strides
;
std
::
vector
<
ptrdiff_t
>
residual_out_strides
;
bool
has_residual_out
;
size_t
ndim
()
const
{
return
shape
.
size
();
}
...
...
@@ -26,11 +26,11 @@ public:
static
utils
::
Result
<
AddRMSNormInfo
>
create
(
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
float
epsilon
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
weight_desc
->
dtype
();
...
...
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
0 → 100644
View file @
784139b9
#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__
#include "../add_rms_norm.h"
DESCRIPTOR
(
metax
)
#endif
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
0 → 100644
View file @
784139b9
#include "../../../devices/metax/metax_common.h"
#include "add_rms_norm_metax.cuh"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Metax platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::metax {
// Internal opaque structure for Metax device handle
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor
Descriptor::~Descriptor() {
delete _opaque;
}
// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
hcStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Extract tensor strides and dimensions
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::metax
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h
0 → 100644
View file @
784139b9
#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__
#include "../add_rms_norm.h"
DESCRIPTOR
(
moore
)
#endif
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
0 → 100644
View file @
784139b9
#include "../../../devices/moore/moore_common.h"
#include "add_rms_norm_moore.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Moore platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::moore {
// Backend-private state held by the Descriptor: a shared handle to the
// Moore device internals, queried below for maxThreadsPerBlock().
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor: releases the backend-private state.
Descriptor::~Descriptor() {
delete _opaque;
}
// Create a descriptor for the add_rms_norm operator on the Moore backend.
// Validates/derives shape, stride and dtype info via AddRMSNormInfo::create
// (propagating its error status via CHECK_RESULT), then allocates the
// Descriptor with a workspace size of 0 — this backend needs no scratch memory.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Dispatch the kernel launch over the supported (atype, wtype) combinations.
// Grid: one block per (batch, head) row, i.e. batch_size * nhead blocks of
// BLOCK_SIZE threads; all combinations accumulate in float.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtype pairs
// (notably any F32 data with a non-F32 weight).
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Supported dtype pairs: F16/BF16 data with F16/BF16/F32 weights, plus
// F32 data with F32 weights. Anything else is rejected below.
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__mt_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
// NOTE(review): the launch error is not checked here; an invalid
// configuration would surface on a later stream operation.
return INFINI_STATUS_SUCCESS;
}
// Execute add_rms_norm: unpack strides/shape from _info and launch with a
// BLOCK_SIZE template parameter matching the device's max threads per block.
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Reject undersized workspaces (this backend requested 0 at create time,
// so this only fails if the caller passes a size below that contract).
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Batch/head strides for each tensor; the innermost `dim` axis is handled
// inside the kernel.
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
// 2-D inputs (batch, dim) are treated as having a single head.
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// Select the kernel instantiation by exact match against the device's
// maxThreadsPerBlock(); unrecognized values are rejected rather than
// rounded down, mirroring the sibling backends.
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::moore
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
View file @
784139b9
...
...
@@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
,
residual_out_desc
);
float
epsilon
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
residual_out_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
);
CHECK_RESULT
(
result
);
auto
info
=
result
.
take
();
...
...
@@ -122,8 +122,8 @@ infiniStatus_t launchKernel(
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
const
{
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
stream
)
const
{
if
(
workspace_size
<
_workspace_size
)
{
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
...
...
@@ -143,7 +143,15 @@ infiniStatus_t Descriptor::calculate(
auto
cuda_stream
=
reinterpret_cast
<
cudaStream_t
>
(
stream
);
// launch kernel with different block sizes
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y_batch
,
stride_y_nhead
,
residual_out
,
stride_residual_out_batch
,
stride_residual_out_nhead
,
a
,
stride_a_batch
,
stride_a_nhead
,
b
,
stride_b_batch
,
stride_b_nhead
,
weight
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_1024
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y_batch
,
stride_y_nhead
,
...
...
@@ -151,8 +159,8 @@ infiniStatus_t Descriptor::calculate(
a
,
stride_a_batch
,
stride_a_nhead
,
b
,
stride_b_batch
,
stride_b_nhead
,
weight
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_
512
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_
512
>
(
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_
2048
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_
2048
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y_batch
,
stride_y_nhead
,
residual_out
,
stride_residual_out_batch
,
stride_residual_out_nhead
,
...
...
src/infiniop/ops/add_rms_norm/operator.cc
View file @
784139b9
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
|| defined(ENABLE_ALI_API)
#include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
...
...
@@ -17,12 +17,10 @@
// #include "bang/add_rms_norm_bang.h"
#endif
#ifdef ENABLE_METAX_API
// TODO: Add Metax implementation
// #include "metax/add_rms_norm_metax.cuh"
#include "metax/add_rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
// TODO: Add Moore implementation
// #include "moore/add_rms_norm_moore.h"
#include "moore/add_rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
// TODO: Add Kunlun implementation
...
...
@@ -32,12 +30,12 @@
__C
infiniStatus_t
infiniopCreateAddRMSNormDescriptor
(
infiniopHandle_t
handle
,
infiniopAddRMSNormDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
float
epsilon
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
...
...
@@ -45,11 +43,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
residual_out_desc, \
a_desc, \
b_desc, \
weight_desc, \
epsilon, \
residual_out_desc)
epsilon)
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
...
...
@@ -61,6 +59,15 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
...
...
@@ -94,6 +101,15 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
...
...
@@ -116,16 +132,16 @@ __C infiniStatus_t infiniopAddRMSNorm(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y,
a, b, weight, residual_ou
t, stream)
->calculate(workspace, workspace_size, y,
residual_out, a, b, weigh
t, stream)
switch
(
desc
->
device_type
)
{
...
...
@@ -138,6 +154,15 @@ __C infiniStatus_t infiniopAddRMSNorm(
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
...
...
@@ -173,6 +198,15 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
#ifdef ENABLE_ILUVATAR_API
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_QY_API
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment