Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
dc8ddd58
Unverified
Commit
dc8ddd58
authored
Aug 22, 2025
by
Ziminli
Committed by
GitHub
Aug 22, 2025
Browse files
issue/388: Support 3D Cases for RMS Norm
parent
47895fae
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
160 additions
and
101 deletions
+160
-101
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
+35
-19
src/infiniop/ops/rms_norm/cuda/kernel.cuh
src/infiniop/ops/rms_norm/cuda/kernel.cuh
+11
-5
src/infiniop/ops/rms_norm/info.h
src/infiniop/ops/rms_norm/info.h
+27
-9
src/infiniop/ops/rms_norm/metax/rms_norm_metax.maca
src/infiniop/ops/rms_norm/metax/rms_norm_metax.maca
+26
-22
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
+23
-19
src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu
src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu
+28
-24
test/infiniop-test/test_generate/testcases/rms_norm.py
test/infiniop-test/test_generate/testcases/rms_norm.py
+4
-0
test/infiniop/rms_norm.py
test/infiniop/rms_norm.py
+6
-3
No files found.
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
View file @
dc8ddd58
...
@@ -21,19 +21,27 @@ infiniStatus_t Descriptor::create(
...
@@ -21,19 +21,27 @@ infiniStatus_t Descriptor::create(
template
<
typename
T
>
template
<
typename
T
>
infiniStatus_t
rmsnorm
(
const
RMSNormInfo
*
info
,
T
*
y
,
const
T
*
x
,
const
T
*
w
)
{
infiniStatus_t
rmsnorm
(
const
RMSNormInfo
*
info
,
T
*
y
,
const
T
*
x
,
const
T
*
w
)
{
const
size_t
batch_size
=
info
->
shape
[
0
];
const
size_t
nhead
=
info
->
shape
.
size
()
>
2
?
info
->
shape
[
1
]
:
1
;
const
size_t
dim
=
info
->
shape
.
back
();
const
ptrdiff_t
total_blocks
=
static_cast
<
ptrdiff_t
>
(
batch_size
*
nhead
);
#pragma omp parallel for
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
ptrdiff_t
(
info
->
shape
[
0
]);
i
++
)
{
for
(
ptrdiff_t
block_idx
=
0
;
block_idx
<
total_blocks
;
++
block_idx
)
{
T
*
x_
=
(
T
*
)(
x
+
i
*
info
->
x_strides
[
0
]);
const
size_t
i
=
block_idx
/
nhead
;
// batch index
T
*
y_
=
(
T
*
)(
y
+
i
*
info
->
y_strides
[
0
]);
const
size_t
j
=
block_idx
%
nhead
;
// head index
const
T
*
x_ptr
=
x
+
i
*
info
->
x_strides
[
0
]
+
j
*
info
->
x_strides
[
1
];
T
*
y_ptr
=
y
+
i
*
info
->
y_strides
[
0
]
+
j
*
info
->
y_strides
[
1
];
// [Reduce] sum of x^2 on last dimension
// [Reduce] sum of x^2 on last dimension
T
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
,
info
->
shape
[
1
]
,
info
->
x_strides
[
1
]
);
T
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
ptr
,
dim
,
info
->
x_strides
.
back
()
);
// 1 / (sqrt(sum/dim + eps))
// 1 / (sqrt(sum/dim + eps))
T
rms
=
(
T
)
1
/
std
::
sqrt
(
ss
/
(
T
)(
info
->
shape
[
1
]
)
+
(
T
)(
info
->
epsilon
));
T
rms
=
(
T
)
1
/
std
::
sqrt
(
ss
/
(
T
)(
dim
)
+
(
T
)(
info
->
epsilon
));
for
(
size_t
j
=
0
;
j
<
info
->
shape
[
1
]
;
j
++
)
{
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
y_
[
j
*
info
->
y_strides
[
1
]]
=
x_
[
j
*
info
->
x_strides
[
1
]
]
*
w
[
j
]
*
rms
;
y_
ptr
[
k
]
=
x_ptr
[
k
]
*
w
[
k
]
*
rms
;
}
}
}
}
...
@@ -45,24 +53,32 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c
...
@@ -45,24 +53,32 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c
static_assert
(
std
::
is_same
<
T
,
fp16_t
>::
value
||
std
::
is_same
<
T
,
bf16_t
>::
value
,
static_assert
(
std
::
is_same
<
T
,
fp16_t
>::
value
||
std
::
is_same
<
T
,
bf16_t
>::
value
,
"T must be fp16_t or bf16_t"
);
"T must be fp16_t or bf16_t"
);
const
size_t
batch_size
=
info
->
shape
[
0
];
const
size_t
nhead
=
info
->
shape
.
size
()
>
2
?
info
->
shape
[
1
]
:
1
;
const
size_t
dim
=
info
->
shape
.
back
();
const
ptrdiff_t
total_blocks
=
static_cast
<
ptrdiff_t
>
(
batch_size
*
nhead
);
#pragma omp parallel for
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
ptrdiff_t
(
info
->
shape
[
0
]);
i
++
)
{
for
(
ptrdiff_t
block_idx
=
0
;
block_idx
<
total_blocks
;
++
block_idx
)
{
T
*
x_
=
(
T
*
)(
x
+
i
*
info
->
x_strides
[
0
]);
const
size_t
i
=
block_idx
/
nhead
;
// batch index
T
*
y_
=
(
T
*
)(
y
+
i
*
info
->
y_strides
[
0
]);
const
size_t
j
=
block_idx
%
nhead
;
// head index
const
T
*
x_ptr
=
x
+
i
*
info
->
x_strides
[
0
]
+
j
*
info
->
x_strides
[
1
];
T
*
y_ptr
=
y
+
i
*
info
->
y_strides
[
0
]
+
j
*
info
->
y_strides
[
1
];
// [Reduce] sum of x^2 on last dimension
// [Reduce] sum of x^2 on last dimension
float
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
,
info
->
shape
[
1
]
,
info
->
x_strides
[
1
]
);
float
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
ptr
,
dim
,
info
->
x_strides
.
back
()
);
// 1 / (sqrt(sum/dim + eps))
// 1 / (sqrt(sum/dim + eps))
float
rms
=
1.
f
/
std
::
sqrt
(
ss
/
(
float
)(
info
->
shape
[
1
]
)
+
info
->
epsilon
);
float
rms
=
1.
f
/
std
::
sqrt
(
ss
/
(
float
)(
dim
)
+
info
->
epsilon
);
for
(
size_t
j
=
0
;
j
<
info
->
shape
[
1
]
;
j
++
)
{
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
float
val
=
utils
::
cast
<
float
>
(
x_
[
j
*
info
->
x_strides
[
1
]
])
*
w
[
j
]
*
rms
;
float
val
=
utils
::
cast
<
float
>
(
x_
ptr
[
k
])
*
w
[
k
]
*
rms
;
y_
[
j
*
info
->
y_strides
[
1
]
]
=
utils
::
cast
<
T
>
(
val
);
y_
ptr
[
k
]
=
utils
::
cast
<
T
>
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
T
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
T
>::
value
)
{
float
val
=
utils
::
cast
<
float
>
(
x_
[
j
*
info
->
x_strides
[
1
]
])
*
utils
::
cast
<
float
>
(
w
[
j
])
*
rms
;
float
val
=
utils
::
cast
<
float
>
(
x_
ptr
[
k
])
*
utils
::
cast
<
float
>
(
w
[
k
])
*
rms
;
y_
[
j
*
info
->
y_strides
[
1
]
]
=
utils
::
cast
<
T
>
(
val
);
y_
ptr
[
k
]
=
utils
::
cast
<
T
>
(
val
);
}
else
{
}
else
{
std
::
abort
();
std
::
abort
();
}
}
...
@@ -93,9 +109,9 @@ infiniStatus_t Descriptor::calculate(
...
@@ -93,9 +109,9 @@ infiniStatus_t Descriptor::calculate(
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
float
*
)
x
,
(
float
*
)
w
));
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
x
,
(
const
float
*
)
w
));
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
double
*
)
x
,
(
double
*
)
w
));
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
const
double
*
)
x
,
(
const
double
*
)
w
));
}
else
{
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
...
...
src/infiniop/ops/rms_norm/cuda/kernel.cuh
View file @
dc8ddd58
...
@@ -4,16 +4,22 @@
...
@@ -4,16 +4,22 @@
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tcompute
,
typename
Tdata
,
typename
Tweight
>
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tcompute
,
typename
Tdata
,
typename
Tweight
>
__device__
void
rmsnormBlock
(
__device__
void
rmsnormBlock
(
Tdata
*
__restrict__
y
,
Tdata
*
__restrict__
y
,
ptrdiff_t
stride_y
,
ptrdiff_t
stride_y_batch
,
ptrdiff_t
stride_y_nhead
,
const
Tdata
*
__restrict__
x
,
const
Tdata
*
__restrict__
x
,
ptrdiff_t
stride_x
,
ptrdiff_t
stride_x_batch
,
ptrdiff_t
stride_x_nhead
,
const
Tweight
*
__restrict__
w
,
const
Tweight
*
__restrict__
w
,
size_t
nhead
,
size_t
dim
,
size_t
dim
,
float
epsilon
)
{
float
epsilon
)
{
// Each block takes care of
a row of continuous data of length dim
// Each block takes care of
one head in one batch
// Each thread deals with every block_size element in the row
// Each thread deals with every block_size element in the row
auto
y_ptr
=
y
+
blockIdx
.
x
*
stride_y
;
size_t
batch_idx
=
blockIdx
.
x
/
nhead
;
auto
x_ptr
=
x
+
blockIdx
.
x
*
stride_x
;
size_t
head_idx
=
blockIdx
.
x
%
nhead
;
auto
y_ptr
=
y
+
batch_idx
*
stride_y_batch
+
head_idx
*
stride_y_nhead
;
auto
x_ptr
=
x
+
batch_idx
*
stride_x_batch
+
head_idx
*
stride_x_nhead
;
auto
w_ptr
=
w
;
auto
w_ptr
=
w
;
// Block-reduce sum of x^2
// Block-reduce sum of x^2
...
...
src/infiniop/ops/rms_norm/info.h
View file @
dc8ddd58
...
@@ -46,21 +46,39 @@ public:
...
@@ -46,21 +46,39 @@ public:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
if
(
y_desc
->
ndim
()
!=
2
||
x_desc
->
ndim
()
!=
2
||
w
_desc
->
ndim
()
!=
1
)
{
const
size_t
y_ndim
=
y
_desc
->
ndim
()
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
const
size_t
x_ndim
=
x_desc
->
ndim
()
;
}
const
size_t
w_ndim
=
w_desc
->
ndim
();
size_t
batch
=
y_desc
->
shape
()[
0
];
if
(
y_ndim
!=
x_ndim
||
w_ndim
!=
1
)
{
size_t
dim
=
y_desc
->
shape
()[
1
];
if
(
x_desc
->
shape
()[
0
]
!=
batch
||
x_desc
->
shape
()[
1
]
!=
dim
||
w_desc
->
shape
()[
0
]
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
if
(
w_desc
->
stride
(
0
)
!=
1
)
{
size_t
batch
=
1
;
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
size_t
nhead
=
1
;
size_t
dim
=
0
;
if
(
y_ndim
==
2
)
{
batch
=
y_desc
->
dim
(
0
);
dim
=
y_desc
->
dim
(
1
);
if
(
x_desc
->
dim
(
0
)
!=
batch
||
x_desc
->
dim
(
1
)
!=
dim
||
w_desc
->
dim
(
0
)
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
else
if
(
y_ndim
==
3
)
{
batch
=
y_desc
->
dim
(
0
);
nhead
=
y_desc
->
dim
(
1
);
dim
=
y_desc
->
dim
(
2
);
if
(
x_desc
->
dim
(
0
)
!=
batch
||
x_desc
->
dim
(
1
)
!=
nhead
||
x_desc
->
dim
(
2
)
!=
dim
||
w_desc
->
dim
(
0
)
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
else
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
if
(
x_desc
->
stride
(
1
)
!=
1
||
y_desc
->
stride
(
1
)
!=
1
)
{
// Check contiguity of the last dimension
if
(
y_desc
->
stride
(
y_ndim
-
1
)
!=
1
||
x_desc
->
stride
(
x_ndim
-
1
)
!=
1
||
w_desc
->
stride
(
w_ndim
-
1
)
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
}
...
...
src/infiniop/ops/rms_norm/metax/rms_norm_metax.maca
View file @
dc8ddd58
...
@@ -11,13 +11,16 @@
...
@@ -11,13 +11,16 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL rmsnormKernel(
INFINIOP_METAX_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ y,
ptrdiff_t stride_y,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
size_t dim,
float epsilon) {
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y
, x, stride_x, w
, dim, epsilon);
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead
, dim, epsilon);
}
}
namespace op::rms_norm::metax {
namespace op::rms_norm::metax {
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
CHECK_RESULT(result);
auto info = result.take();
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
std::move(info),
...
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
...
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim,
uint32_t batch_size, size_t
nhead, size_t
dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
void *y, infiniDtype_t atype, ptrdiff_t stride_y
_batch, ptrdiff_t stride_y_nhead
,
const void *x, ptrdiff_t stride_x,
const void *x, ptrdiff_t stride_x
_batch, ptrdiff_t stride_x_nhead
,
const void *w, infiniDtype_t wtype,
const void *w, infiniDtype_t wtype,
float epsilon,
float epsilon,
hcStream_t stream) {
hcStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, stream>>>( \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(y), \
stride_y, \
stride_y_batch, \
reinterpret_cast<const Tdata *>(x), \
stride_y_nhead, \
stride_x, \
reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tweight *>(w), \
stride_x_batch, \
dim, \
stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
...
@@ -102,15 +103,18 @@ infiniStatus_t Descriptor::calculate(
...
@@ -102,15 +103,18 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
auto stride_x = _info.x_strides[0];
auto stride_x_batch = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
auto stride_x_nhead = _info.x_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim();
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream = reinterpret_cast<hcStream_t>(stream_);
auto stream = reinterpret_cast<hcStream_t>(stream_);
// launch kernel with different block sizes
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y
, x, stride_x
, w, _info.wtype, _info.epsilon, stream));
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size,
nhead,
dim, y, _info.atype, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead
, w, _info.wtype, _info.epsilon, stream));
} else {
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
...
...
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
View file @
dc8ddd58
...
@@ -11,13 +11,16 @@
...
@@ -11,13 +11,16 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL rmsnormKernel(
INFINIOP_MOORE_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ y,
ptrdiff_t stride_y,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
size_t dim,
float epsilon) {
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y
, x, stride_x, w
, dim, epsilon);
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead
, dim, epsilon);
}
}
namespace op::rms_norm::moore {
namespace op::rms_norm::moore {
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
CHECK_RESULT(result);
auto info = result.take();
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info),
std::move(info),
...
@@ -57,20 +55,23 @@ infiniStatus_t Descriptor::create(
...
@@ -57,20 +55,23 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim,
uint32_t batch_size, size_t
nhead, size_t
dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
void *y, infiniDtype_t atype, ptrdiff_t stride_y
_batch, ptrdiff_t stride_y_nhead
,
const void *x, ptrdiff_t stride_x,
const void *x, ptrdiff_t stride_x
_batch, ptrdiff_t stride_x_nhead
,
const void *w, infiniDtype_t wtype,
const void *w, infiniDtype_t wtype,
float epsilon,
float epsilon,
musaStream_t musa_stream) {
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size
* nhead
, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(y), \
stride_y, \
stride_y_batch, \
stride_y_nhead, \
reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
stride_x_batch, \
stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
dim, \
epsilon)
epsilon)
...
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
...
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
auto stride_x = _info.x_strides[0];
auto stride_x_batch = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
auto stride_x_nhead = _info.x_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim();
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y
, x, stride_x
, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size,
nhead,
dim, y, _info.atype, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead
, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y
, x, stride_x
, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size,
nhead,
dim, y, _info.atype, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead
, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y
, x, stride_x
, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size,
nhead,
dim, y, _info.atype, stride_y
_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead
, w, _info.wtype, _info.epsilon, musa_stream));
} else {
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
...
...
src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu
View file @
dc8ddd58
...
@@ -11,13 +11,16 @@
...
@@ -11,13 +11,16 @@
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tcompute
,
typename
Tdata
,
typename
Tweight
>
template
<
unsigned
int
BLOCK_SIZE
,
typename
Tcompute
,
typename
Tdata
,
typename
Tweight
>
INFINIOP_CUDA_KERNEL
rmsnormKernel
(
INFINIOP_CUDA_KERNEL
rmsnormKernel
(
Tdata
*
__restrict__
y
,
Tdata
*
__restrict__
y
,
ptrdiff_t
stride_y
,
ptrdiff_t
stride_y_batch
,
ptrdiff_t
stride_y_nhead
,
const
Tdata
*
__restrict__
x
,
const
Tdata
*
__restrict__
x
,
ptrdiff_t
stride_x
,
ptrdiff_t
stride_x_batch
,
ptrdiff_t
stride_x_nhead
,
const
Tweight
*
__restrict__
w
,
const
Tweight
*
__restrict__
w
,
size_t
nhead
,
size_t
dim
,
size_t
dim
,
float
epsilon
)
{
float
epsilon
)
{
rmsnormBlock
<
BLOCK_SIZE
,
Tcompute
>
(
y
,
stride_y
,
x
,
stride_x
,
w
,
dim
,
epsilon
);
rmsnormBlock
<
BLOCK_SIZE
,
Tcompute
>
(
y
,
stride_y
_batch
,
stride_y_nhead
,
x
,
stride_x_batch
,
stride_x_nhead
,
w
,
nhead
,
dim
,
epsilon
);
}
}
namespace
op
::
rms_norm
::
nvidia
{
namespace
op
::
rms_norm
::
nvidia
{
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
...
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT
(
result
);
CHECK_RESULT
(
result
);
auto
info
=
result
.
take
();
auto
info
=
result
.
take
();
// only support contiguous last dimension
if
(
info
.
x_strides
[
1
]
!=
1
||
info
.
y_strides
[
1
]
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
*
desc_ptr
=
new
Descriptor
(
*
desc_ptr
=
new
Descriptor
(
new
Opaque
{
reinterpret_cast
<
device
::
nvidia
::
Handle
*>
(
handle
)
->
internal
()},
new
Opaque
{
reinterpret_cast
<
device
::
nvidia
::
Handle
*>
(
handle
)
->
internal
()},
std
::
move
(
info
),
std
::
move
(
info
),
...
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
...
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types
// launch kernel with different data types
template
<
unsigned
int
BLOCK_SIZE
>
template
<
unsigned
int
BLOCK_SIZE
>
infiniStatus_t
launchKernel
(
infiniStatus_t
launchKernel
(
uint32_t
batch_size
,
size_t
dim
,
uint32_t
batch_size
,
size_t
nhead
,
size_t
dim
,
void
*
y
,
infiniDtype_t
atype
,
ptrdiff_t
stride_y
,
void
*
y
,
infiniDtype_t
atype
,
ptrdiff_t
stride_y
_batch
,
ptrdiff_t
stride_y_nhead
,
const
void
*
x
,
ptrdiff_t
stride_x
,
const
void
*
x
,
ptrdiff_t
stride_x
_batch
,
ptrdiff_t
stride_x_nhead
,
const
void
*
w
,
infiniDtype_t
wtype
,
const
void
*
w
,
infiniDtype_t
wtype
,
float
epsilon
,
float
epsilon
,
cudaStream_t
cuda_stream
)
{
cudaStream_t
cuda_stream
)
{
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(y), \
stride_y, \
stride_y_batch, \
reinterpret_cast<const Tdata *>(x), \
stride_y_nhead, \
stride_x, \
reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tweight *>(w), \
stride_x_batch, \
dim, \
stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
epsilon)
if
(
atype
==
INFINI_DTYPE_F16
&&
wtype
==
INFINI_DTYPE_F16
)
{
if
(
atype
==
INFINI_DTYPE_F16
&&
wtype
==
INFINI_DTYPE_F16
)
{
...
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
...
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
}
}
auto
stride_x
=
_info
.
x_strides
[
0
];
auto
stride_x_batch
=
_info
.
x_strides
[
0
];
auto
stride_y
=
_info
.
y_strides
[
0
];
auto
stride_x_nhead
=
_info
.
x_strides
[
1
];
auto
stride_y_batch
=
_info
.
y_strides
[
0
];
auto
stride_y_nhead
=
_info
.
y_strides
[
1
];
auto
dim
=
_info
.
dim
();
auto
dim
=
_info
.
dim
();
uint32_t
batch_size
=
static_cast
<
uint32_t
>
(
_info
.
shape
[
0
]);
uint32_t
batch_size
=
static_cast
<
uint32_t
>
(
_info
.
shape
[
0
]);
size_t
nhead
=
_info
.
shape
.
size
()
>
2
?
_info
.
shape
[
1
]
:
1
;
auto
cuda_stream
=
reinterpret_cast
<
cudaStream_t
>
(
stream
);
auto
cuda_stream
=
reinterpret_cast
<
cudaStream_t
>
(
stream
);
// launch kernel with different block sizes
// launch kernel with different block sizes
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_1024
>
(
batch_size
,
dim
,
y
,
_info
.
atype
,
stride_y
,
x
,
stride_x
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_1024
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y
_batch
,
stride_y_nhead
,
x
,
stride_x_batch
,
stride_x_nhead
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
batch_size
,
dim
,
y
,
_info
.
atype
,
stride_y
,
x
,
stride_x
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y
_batch
,
stride_y_nhead
,
x
,
stride_x_batch
,
stride_x_nhead
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
batch_size
,
dim
,
y
,
_info
.
atype
,
stride_y
,
x
,
stride_x
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
batch_size
,
nhead
,
dim
,
y
,
_info
.
atype
,
stride_y
_batch
,
stride_y_nhead
,
x
,
stride_x_batch
,
stride_x_nhead
,
w
,
_info
.
wtype
,
_info
.
epsilon
,
cuda_stream
));
}
else
{
}
else
{
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
}
}
...
...
test/infiniop-test/test_generate/testcases/rms_norm.py
View file @
dc8ddd58
...
@@ -97,6 +97,10 @@ if __name__ == "__main__":
...
@@ -97,6 +97,10 @@ if __name__ == "__main__":
((
500
,
4096
),
None
,
(
8192
,
1
)),
((
500
,
4096
),
None
,
(
8192
,
1
)),
((
4
,
512
),
(
1024
,
1
),
(
512
,
1
)),
((
4
,
512
),
(
1024
,
1
),
(
512
,
1
)),
((
4
,
512
),
None
,
(
2048
,
1
)),
((
4
,
512
),
None
,
(
2048
,
1
)),
((
3
,
4
,
512
),
None
,
None
),
((
3
,
4
,
512
),
None
,
(
4096
,
1024
,
1
)),
((
3
,
4
,
512
),
(
4096
,
1024
,
1
),
None
),
((
3
,
4
,
512
),
(
4096
,
1024
,
1
),
(
4096
,
1024
,
1
)),
]
]
_TENSOR_DTYPES_
=
[
np
.
float32
,
np
.
float16
]
_TENSOR_DTYPES_
=
[
np
.
float32
,
np
.
float16
]
for
dtype
in
_TENSOR_DTYPES_
:
for
dtype
in
_TENSOR_DTYPES_
:
...
...
test/infiniop/rms_norm.py
View file @
dc8ddd58
...
@@ -25,11 +25,14 @@ from libinfiniop import (
...
@@ -25,11 +25,14 @@ from libinfiniop import (
_TEST_CASES_
=
[
_TEST_CASES_
=
[
# y_shape, x_shape, w_shape, y_stride, x_stride
# y_shape, x_shape, w_shape, y_stride, x_stride
((
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
),
((
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
),
((
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
),
((
2
,
4
),
(
2
,
4
),
(
4
,),
None
,
None
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
),
((
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
None
,
None
),
((
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
(
12
,
8
,
1
),
(
12
,
8
,
1
)),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
)),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
)),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
)),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
None
,
None
),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
)),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
)),
]
]
# w (weight) types
# w (weight) types
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment