Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c35920e2
Unverified
Commit
c35920e2
authored
Aug 25, 2025
by
zhangyue
Committed by
GitHub
Aug 25, 2025
Browse files
Merge pull request #400 from InfiniTensor/issue/388-rmsnorm3d-p800
issue390: p800 rmsnorm 3d
parents
c9f5f2ec
e6dd5604
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
27 deletions
+32
-27
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
+32
-27
No files found.
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
View file @
c35920e2
...
@@ -36,13 +36,6 @@ __global__ void rmsnormKernel(
...
@@ -36,13 +36,6 @@ __global__ void rmsnormKernel(
sync_cluster();
sync_cluster();
}
}
// Instantiate the kernel for different data types and block sizes
INSTANTIATE_RMSNORM_KERNEL(64, float, float, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, bfloat16_t);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, half);
namespace op::rms_norm::kunlun {
namespace op::rms_norm::kunlun {
struct Descriptor::Opaque {
struct Descriptor::Opaque {
...
@@ -64,13 +57,6 @@ infiniStatus_t Descriptor::create(
...
@@ -64,13 +57,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
CHECK_RESULT(result);
auto info = result.take();
auto info = result.take();
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (info.ndim() != 2) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
*desc_ptr = new Descriptor(
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
info,
info,
...
@@ -82,18 +68,29 @@ infiniStatus_t Descriptor::create(
...
@@ -82,18 +68,29 @@ infiniStatus_t Descriptor::create(
template <unsigned int BLOCK_SIZE>
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
infiniStatus_t launchKernel(
uint32
_t batch_size,
uint32
_t dim,
size
_t batch_size,
size_t nhead, size
_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
void *y, infiniDtype_t atype, ptrdiff_t stride_y
_batch, ptrdiff_t stride_y_nhead
,
const void *x, ptrdiff_t stride_x,
const void *x, ptrdiff_t stride_x
_batch, ptrdiff_t stride_x_nhead
,
const void *w, infiniDtype_t wtype,
const void *w, infiniDtype_t wtype,
float epsilon,
float epsilon,
kunlunStream_t stream) {
kunlunStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
uint32_t batch_size_ = static_cast<uint32_t>(batch_size);
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, stream>>>( \
uint32_t dim_ = static_cast<uint32_t>(dim);
static_cast<Tdata *>(y), stride_y, \
uint32_t nhead_ = static_cast<uint32_t>(nhead);
static_cast<const Tdata *>(x), stride_x, \
auto stride_x_batch_ = static_cast<int32_t>(stride_x_batch);
static_cast<const Tweight *>(w), dim, epsilon);
auto stride_x_nhead_ = static_cast<int32_t>(stride_x_nhead);
auto stride_y_batch_ = static_cast<int32_t>(stride_y_batch);
auto stride_y_nhead_ = static_cast<int32_t>(stride_y_nhead);
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
for (uint32_t i = 0; i < batch_size_; ++i) { \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight> \
<<<nhead_, BLOCK_SIZE, stream>>>( \
(Tdata *)y + i * stride_y_batch_, stride_y_nhead_, \
(const Tdata *)x + i * stride_x_batch_, stride_x_nhead_, \
(const Tweight *)w, dim, epsilon); \
}
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
LAUNCH_KERNEL(half, half, float);
...
@@ -123,13 +120,21 @@ infiniStatus_t Descriptor::calculate(
...
@@ -123,13 +120,21 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
auto stride_x = static_cast<int32_t>(_info.x_strides[0]);
auto stride_x_batch = _info.x_strides[0];
auto stride_y = static_cast<int32_t>(_info.y_strides[0]);
auto stride_x_nhead = _info.x_strides[1];
auto dim = static_cast<uint32_t>(_info.dim());
auto stride_y_batch = _info.y_strides[0];
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim();
auto batch_size = _info.shape[0];
auto nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
kunlunStream_t stream_ = static_cast<kunlunStream_t>(stream);
// launch kernel with different block sizes
// launch kernel with different block sizes
CHECK_STATUS(launchKernel<64>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, stream));
CHECK_STATUS(launchKernel<64>(batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
x, stride_x_batch, stride_x_nhead,
w, _info.wtype, _info.epsilon, stream_));
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment