jerrrrry / infinicore · Commits

Commit 7712471f, authored Dec 24, 2025 by zhuyue
Parent: 2a432b34

Add NVIDIA GPU implementation for add_rms_norm and make residual_out required.

Showing 8 changed files with 355 additions and 134 deletions:

    python/infinicore/ops/add_rms_norm.py                          +22  -6
    src/infiniop/ops/add_rms_norm/add_rms_norm.h                   +10  -10
    src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc          +18  -50
    src/infiniop/ops/add_rms_norm/cuda/kernel.cuh                  +63  -0
    src/infiniop/ops/add_rms_norm/info.h                           +25  -33
    src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu    +175 -0
    src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh   +8   -0
    src/infiniop/ops/add_rms_norm/operator.cc                      +34  -35
python/infinicore/ops/add_rms_norm.py

The single hunk (@@ -5,27 +5,43 @@, context `from infinicore.tensor import Tensor`) adds docstrings to both wrappers and reformats the binding calls. The two functions after the change:

```python
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """
    Fused Add and RMS Normalization.

    Args:
        a: First input tensor
        b: Second input tensor
        weight: Scale weights
        epsilon: Small constant for numerical stability, default is 1e-5
        out: Optional output tuple (y, residual_out) for in-place operation

    Returns:
        Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
        The add_result can be used as residual for subsequent layers.
    """
    if out is None:
        result = _infinicore.add_rms_norm(
            a._underlying, b._underlying, weight._underlying, epsilon
        )
        return (Tensor(result[0]), Tensor(result[1]))

    y, residual_out = out
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )
    return (y, residual_out)


def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
    """In-place Fused Add and RMS Normalization."""
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )
```
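As a sanity reference for the contract stated in that docstring, the same math can be written directly in NumPy. This is an illustrative sketch, not part of the commit; `add_rms_norm_reference` is a hypothetical name:

```python
import numpy as np

def add_rms_norm_reference(a, b, weight, epsilon=1e-5):
    """Reference semantics: returns (RMSNorm(a + b) * weight, a + b), rows on the last axis."""
    residual = a + b  # the add result, reusable as the next layer's residual
    ms = np.mean(residual.astype(np.float32) ** 2, axis=-1, keepdims=True)
    y = residual * weight / np.sqrt(ms + epsilon)  # scale each row by 1/sqrt(mean-square + eps)
    return y.astype(a.dtype), residual

# Example: a batch of 4 rows of width 8.
a = np.random.randn(4, 8).astype(np.float32)
b = np.random.randn(4, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
y, residual = add_rms_norm_reference(a, b, w)
```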
src/infiniop/ops/add_rms_norm/add_rms_norm.h

All three hunks in this header only realign the line-continuation backslashes of the DESCRIPTOR macro; the declarations themselves are unchanged. Shown post-change as context:

```diff
@@ -6,8 +6,8 @@
 #define DESCRIPTOR(NAMESPACE)                                      \
                                                                    \
     namespace op::add_rms_norm::NAMESPACE {                        \
     class Descriptor final : public InfiniopDescriptor {           \
         struct Opaque;                                             \
         Opaque *_opaque;                                           \
         AddRMSNormInfo _info;                                      \
@@ -19,7 +19,7 @@
             size_t workspace_size,                                 \
             infiniDevice_t device_type,                            \
             int device_id)                                         \
             : InfiniopDescriptor{device_type, device_id},          \
               _opaque(opaque),                                     \
               _info(info),                                         \
               _workspace_size(workspace_size) {}                   \
@@ -29,24 +29,24 @@
         size_t workspaceSize() const { return _workspace_size; }   \
                                                                    \
         static infiniStatus_t create(                              \
             infiniopHandle_t handle,                               \
             Descriptor **desc_ptr,                                 \
             infiniopTensorDescriptor_t y_desc,                     \
             infiniopTensorDescriptor_t a_desc,                     \
             infiniopTensorDescriptor_t b_desc,                     \
             infiniopTensorDescriptor_t weight_desc,                \
             float epsilon,                                         \
             infiniopTensorDescriptor_t residual_out_desc);         \
                                                                    \
         infiniStatus_t calculate(                                  \
             void *workspace, size_t workspace_size,                \
             void *y,                                               \
             const void *a,                                         \
             const void *b,                                         \
             const void *weight,                                    \
             void *residual_out,                                    \
             void *stream) const;                                   \
     };                                                             \
     }
```
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc

Both CPU kernels drop the "maybe-null residual buffer" branching: since residual_out is now required, the pointer is computed unconditionally, the add result is always stored, and the normalization pass always reuses it.

```diff
@@ -36,16 +36,13 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
             const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
             const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
             T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
-            T *residual_out_ptr = info->has_residual_out
-                ? (residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1])
-                : nullptr;
+            T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

             // Compute add(a, b) once and store it
             T sum_squared = (T)0;
             for (size_t k = 0; k < dim; k++) {
                 T sum_val = a_ptr[k] + b_ptr[k];
-                if (residual_out_ptr != nullptr) {
-                    residual_out_ptr[k] = sum_val; // Store add result
-                }
+                residual_out_ptr[k] = sum_val; // Store add result
                 sum_squared += sum_val * sum_val;
             }
@@ -54,18 +51,9 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
             T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));

             // Apply normalization: y = (a + b) * w * rms
-            // Reuse the stored sum values if residual_out was computed, otherwise recompute
-            if (residual_out_ptr != nullptr) {
-                // Reuse stored values
-                for (size_t k = 0; k < dim; k++) {
-                    y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
-                }
-            } else {
-                // Recompute sum
-                for (size_t k = 0; k < dim; k++) {
-                    T sum_val = a_ptr[k] + b_ptr[k];
-                    y_ptr[k] = sum_val * w[k] * rms;
-                }
-            }
+            // Reuse stored values from residual_out
+            for (size_t k = 0; k < dim; k++) {
+                y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
+            }
@@ -90,16 +78,13 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
             const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
             const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
             T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
-            T *residual_out_ptr = info->has_residual_out
-                ? (residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1])
-                : nullptr;
+            T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

             // Compute sum of squares for RMS normalization and store add result
             float sum_squared = 0.0f;
             for (size_t k = 0; k < dim; k++) {
                 float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
-                if (residual_out_ptr != nullptr) {
-                    residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
-                }
+                residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
                 sum_squared += sum_val * sum_val;
             }
@@ -107,35 +92,18 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
             float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);

             // Apply normalization: y = (a + b) * w * rms
-            // Reuse stored values if residual_out was computed, otherwise recompute
-            if (residual_out_ptr != nullptr) {
-                // Reuse stored values
-                for (size_t k = 0; k < dim; k++) {
-                    float sum_val = utils::cast<float>(residual_out_ptr[k]);
-                    float val;
-                    if constexpr (std::is_same<Tw, float>::value) {
-                        val = sum_val * w[k] * rms;
-                    } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
-                        val = sum_val * utils::cast<float>(w[k]) * rms;
-                    } else {
-                        std::abort();
-                    }
-                    y_ptr[k] = utils::cast<T>(val);
-                }
-            } else {
-                // Recompute sum
-                for (size_t k = 0; k < dim; k++) {
-                    float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
-                    float val;
-                    if constexpr (std::is_same<Tw, float>::value) {
-                        val = sum_val * w[k] * rms;
-                    } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
-                        val = sum_val * utils::cast<float>(w[k]) * rms;
-                    } else {
-                        std::abort();
-                    }
-                    y_ptr[k] = utils::cast<T>(val);
-                }
-            }
+            // Reuse stored values from residual_out
+            for (size_t k = 0; k < dim; k++) {
+                float sum_val = utils::cast<float>(residual_out_ptr[k]);
+                float val;
+                if constexpr (std::is_same<Tw, float>::value) {
+                    val = sum_val * w[k] * rms;
+                } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
+                    val = sum_val * utils::cast<float>(w[k]) * rms;
+                } else {
+                    std::abort();
+                }
+                y_ptr[k] = utils::cast<T>(val);
+            }
```
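With residual_out guaranteed, both CPU paths share the same two-pass shape: store a + b once, then read it back for normalization. A minimal Python sketch of that structure (illustrative only; strides and the fp16/bf16 weight dispatch are omitted, and `add_rms_norm_row` is a hypothetical name):

```python
import numpy as np

def add_rms_norm_row(a_row, b_row, w, residual_out_row, y_row, epsilon=1e-5):
    # Pass 1: compute a + b once, store it in residual_out, accumulate sum of squares.
    sum_squared = 0.0
    for k in range(len(a_row)):
        s = float(a_row[k]) + float(b_row[k])
        residual_out_row[k] = s  # always stored: residual_out is required now
        sum_squared += s * s

    rms = 1.0 / np.sqrt(sum_squared / len(a_row) + epsilon)

    # Pass 2: reuse the stored sums instead of recomputing a + b.
    for k in range(len(a_row)):
        y_row[k] = residual_out_row[k] * w[k] * rms

a_row = np.array([1.0, 2.0, 3.0])
b_row = np.array([0.5, 0.5, 0.5])
w = np.ones(3)
residual_out_row = np.empty(3)
y_row = np.empty(3)
add_rms_norm_row(a_row, b_row, w, residual_out_row, y_row)
```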
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh (new file)

```cuda
#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__

#include <cub/block/block_reduce.cuh>

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    // Each block takes care of one head in one batch
    // Each thread deals with every block_size element in the row
    size_t batch_idx = blockIdx.x / nhead;
    size_t head_idx = blockIdx.x % nhead;

    auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    auto w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;

    // Compute add(a, b) and sum of squares in one pass
    Tcompute sum_squared = 0;
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
        residual_out_ptr[i] = Tdata(sum_val); // Store add result
        sum_squared += sum_val * sum_val;
    }

    // Block-reduce sum of squares
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);

    // Thread_0 computes RMS = 1/sqrt(ss/dim + epsilon) and stores it in shared memory
    __shared__ Tcompute rms;
    if (threadIdx.x == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
    }
    __syncthreads();

    // Apply normalization: y = (a + b) * w * rms
    // Reuse stored values from residual_out
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
        y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
    }
}

#endif
```
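The kernel maps one CUDA block to one (batch, head) row: block i handles batch `i / nhead`, head `i % nhead`, each thread strides through the row in steps of BLOCK_SIZE, and cub::BlockReduce sums the per-thread partials. A tiny Python model of that indexing and reduction, for intuition only (`block_sum_squared` is a hypothetical helper, not CUDA code):

```python
import numpy as np

BLOCK_SIZE = 4  # tiny block for illustration; the real kernel uses 512 or 1024 threads

def block_sum_squared(row, block_size=BLOCK_SIZE):
    """Emulate the per-thread strided loop followed by a block-wide reduction."""
    partials = np.zeros(block_size)
    for tid in range(block_size):                   # each "thread"...
        for i in range(tid, len(row), block_size):  # ...visits every block_size-th element
            partials[tid] += row[i] * row[i]
    return partials.sum()                           # cub::BlockReduce(...).Sum(...)

row = np.arange(10, dtype=np.float64)
assert np.isclose(block_sum_squared(row), np.sum(row * row))

# Block-to-row mapping used by the kernel: one block per (batch, head) pair.
nhead = 3
mapping = [divmod(block_idx, nhead) for block_idx in range(6)]
# -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```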
src/infiniop/ops/add_rms_norm/info.h

The early hunks rewrap long dtype, shape, and stride conditions without changing them (shown post-change as context); the later hunks make residual_out mandatory:

```diff
@@ -34,12 +34,12 @@ public:
         auto atype = y_desc->dtype();
         auto wtype = weight_desc->dtype();

         // Check that all input tensors have the same dtype
         if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }

         if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
             // For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
             if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
@@ -71,9 +71,7 @@ public:
             batch = y_desc->dim(0);
             dim = y_desc->dim(1);
             if (a_desc->dim(0) != batch || a_desc->dim(1) != dim
                 || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) {
                 return INFINI_STATUS_BAD_TENSOR_SHAPE;
             }
         } else if (y_ndim == 3) {
@@ -81,9 +79,7 @@ public:
             nhead = y_desc->dim(1);
             dim = y_desc->dim(2);
             if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim
                 || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim
                 || weight_desc->dim(0) != dim) {
                 return INFINI_STATUS_BAD_TENSOR_SHAPE;
             }
         } else {
@@ -91,32 +87,30 @@ public:
         }

         // Check contiguity of the last dimension
         if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1
             || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) {
             return INFINI_STATUS_BAD_TENSOR_STRIDES;
         }

-        // Check residual_out_desc if provided
-        bool has_residual_out = (residual_out_desc != nullptr);
-        if (has_residual_out) {
-            const size_t residual_out_ndim = residual_out_desc->ndim();
-            if (residual_out_ndim != y_ndim) {
-                return INFINI_STATUS_BAD_TENSOR_SHAPE;
-            }
-            if (residual_out_desc->dtype() != atype) {
-                return INFINI_STATUS_BAD_TENSOR_DTYPE;
-            }
-            // Check shape matches
-            for (size_t i = 0; i < y_ndim; i++) {
-                if (residual_out_desc->dim(i) != y_desc->dim(i)) {
-                    return INFINI_STATUS_BAD_TENSOR_SHAPE;
-                }
-            }
-            if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
-                return INFINI_STATUS_BAD_TENSOR_STRIDES;
-            }
-        }
+        // residual_out_desc is required (always needed for fused operator)
+        if (residual_out_desc == nullptr) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
+        const size_t residual_out_ndim = residual_out_desc->ndim();
+        if (residual_out_ndim != y_ndim) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        if (residual_out_desc->dtype() != atype) {
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
+        // Check shape matches
+        for (size_t i = 0; i < y_ndim; i++) {
+            if (residual_out_desc->dim(i) != y_desc->dim(i)) {
+                return INFINI_STATUS_BAD_TENSOR_SHAPE;
+            }
+        }
+        if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
+            return INFINI_STATUS_BAD_TENSOR_STRIDES;
+        }

         AddRMSNormInfo info;
@@ -127,10 +121,8 @@ public:
         info.y_strides = y_desc->strides();
         info.a_strides = a_desc->strides();
         info.b_strides = b_desc->strides();
-        info.has_residual_out = has_residual_out;
-        if (has_residual_out) {
-            info.residual_out_strides = residual_out_desc->strides();
-        }
+        info.has_residual_out = true; // Always true now
+        info.residual_out_strides = residual_out_desc->strides();

         return utils::Result<AddRMSNormInfo>(info);
     }
 };
```
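Mirrored in Python, the validation contract for the now-required residual_out is easy to state. An illustrative sketch only; the descriptor objects here are hypothetical stand-ins with ndim/shape/dtype/strides attributes, and the string return values stand in for the infiniStatus_t codes:

```python
def check_residual_out(residual_out_desc, y_desc, atype):
    """Mirror of the new info.h checks: residual_out must exist and match y."""
    if residual_out_desc is None:
        return "BAD_PARAM"  # residual_out is now required
    if residual_out_desc.ndim != y_desc.ndim:
        return "BAD_TENSOR_SHAPE"
    if residual_out_desc.dtype != atype:
        return "BAD_TENSOR_DTYPE"
    if residual_out_desc.shape != y_desc.shape:  # per-dimension check, collapsed to a tuple compare
        return "BAD_TENSOR_SHAPE"
    if residual_out_desc.strides[-1] != 1:  # last dimension must be contiguous
        return "BAD_TENSOR_STRIDES"
    return "SUCCESS"
```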
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu (new file)

```cuda
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "add_rms_norm_nvidia.cuh"

#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}

namespace op::add_rms_norm::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        std::move(info), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon, cudaStream_t cuda_stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                                  \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
        reinterpret_cast<Tdata *>(y),                                                                            \
        reinterpret_cast<Tdata *>(residual_out),                                                                 \
        stride_y_batch,                                                                                          \
        stride_y_nhead,                                                                                          \
        stride_residual_out_batch,                                                                               \
        stride_residual_out_nhead,                                                                               \
        reinterpret_cast<const Tdata *>(a),                                                                      \
        stride_a_batch,                                                                                          \
        stride_a_nhead,                                                                                          \
        reinterpret_cast<const Tdata *>(b),                                                                      \
        stride_b_batch,                                                                                          \
        stride_b_nhead,                                                                                          \
        reinterpret_cast<const Tweight *>(w),                                                                    \
        nhead,                                                                                                   \
        dim,                                                                                                     \
        epsilon)

    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __nv_bfloat16, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__nv_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__nv_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *residual_out,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

    // launch kernel with different block sizes
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::nvidia
```
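The LAUNCH_KERNEL dispatch enumerates the supported (activation dtype, weight dtype) pairs, always accumulating in float; any other pair returns INFINI_STATUS_BAD_TENSOR_DTYPE. The table is small enough to restate as a sketch (a Python dict for reference, using the CUDA type names from the file):

```python
# (atype, wtype) -> (Tdata, Tweight, Tcompute)
SUPPORTED_KERNELS = {
    ("F16",  "F16"):  ("half",          "half",          "float"),
    ("F16",  "BF16"): ("half",          "__nv_bfloat16", "float"),
    ("F16",  "F32"):  ("half",          "float",         "float"),
    ("BF16", "BF16"): ("__nv_bfloat16", "__nv_bfloat16", "float"),
    ("BF16", "F16"):  ("__nv_bfloat16", "half",          "float"),
    ("BF16", "F32"):  ("__nv_bfloat16", "float",         "float"),
    ("F32",  "F32"):  ("float",         "float",         "float"),
}
```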
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh (new file)

```cuda
#ifndef __ADD_RMS_NORM_NVIDIA_CUDA_H__
#define __ADD_RMS_NORM_NVIDIA_CUDA_H__

#include "../add_rms_norm.h"

DESCRIPTOR(nvidia)

#endif
```
src/infiniop/ops/add_rms_norm/operator.cc

The NVIDIA implementation is wired in: the include is uncommented,

```diff
@@ -6,8 +6,7 @@
 #include "cpu/add_rms_norm_cpu.h"
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-// TODO: Add NVIDIA implementation
-// #include "nvidia/add_rms_norm_nvidia.cuh"
+#include "nvidia/add_rms_norm_nvidia.cuh"
 #endif
 #ifdef ENABLE_ASCEND_API
 // TODO: Add Ascend implementation
```

the CREATE, GET, CALCULATE, and DESTROY macros have their continuation backslashes realigned (content unchanged):

```cpp
#define CREATE(CASE, NAMESPACE)                                                         \
    case CASE:                                                                          \
        return op::add_rms_norm::NAMESPACE::Descriptor::create(                         \
            handle,                                                                     \
            reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr),     \
            y_desc,                                                                     \
            a_desc,                                                                     \
            b_desc,                                                                     \
            weight_desc,                                                                \
            epsilon,                                                                    \
            residual_out_desc)

#define GET(CASE, NAMESPACE)                                                                         \
    case CASE:                                                                                       \
        *size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize();  \
        return INFINI_STATUS_SUCCESS

#define CALCULATE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                          \
        return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)

#define DESTROY(CASE, NAMESPACE)                                                        \
    case CASE:                                                                          \
        delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc);       \
        return INFINI_STATUS_SUCCESS
```

and the previously commented-out dispatch cases are enabled for every CUDA-compatible backend in infiniopCreateAddRMSNormDescriptor:

```diff
@@ -57,16 +56,16 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
     CREATE(INFINI_DEVICE_CPU, cpu);
 #endif
 #ifdef ENABLE_NVIDIA_API
-    // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+    CREATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #ifdef ENABLE_ILUVATAR_API
-    // CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
+    CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
 #ifdef ENABLE_QY_API
-    // CREATE(INFINI_DEVICE_QY, nvidia);
+    CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
 #ifdef ENABLE_HYGON_API
-    // CREATE(INFINI_DEVICE_HYGON, nvidia);
+    CREATE(INFINI_DEVICE_HYGON, nvidia);
 #endif
 #ifdef ENABLE_KUNLUN_API
     // CREATE(INFINI_DEVICE_KUNLUN, kunlun);
```

The same uncommenting pattern is applied to GET(INFINI_DEVICE_*, nvidia) in infiniopGetAddRMSNormWorkspaceSize, CALCULATE(INFINI_DEVICE_*, nvidia) in infiniopAddRMSNorm, and DESTROY(INFINI_DEVICE_*, nvidia) in infiniopDestroyAddRMSNormDescriptor, for the NVIDIA, ILUVATAR, QY, and HYGON devices; the KUNLUN and Ascend cases remain commented out.
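One pattern worth noting: ILUVATAR, QY, and HYGON all dispatch to the same op::add_rms_norm::nvidia descriptor, since those backends are CUDA-compatible. A Python sketch of that device-to-backend mapping, for intuition only (the fallthrough error name here is an assumption, not taken from the file):

```python
# Several device types share one CUDA-compatible backend implementation.
DEVICE_BACKEND = {
    "INFINI_DEVICE_CPU":      "cpu",
    "INFINI_DEVICE_NVIDIA":   "nvidia",
    "INFINI_DEVICE_ILUVATAR": "nvidia",  # CUDA-compatible: reuses the nvidia descriptor
    "INFINI_DEVICE_QY":       "nvidia",
    "INFINI_DEVICE_HYGON":    "nvidia",
}

def backend_for(device_type):
    # Mirrors the switch in operator.cc: unknown devices fall through to an error status.
    try:
        return DEVICE_BACKEND[device_type]
    except KeyError:
        raise ValueError("device type not supported") from None
```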