Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
1048c1bc
Commit
1048c1bc
authored
Sep 03, 2025
by
zhangyue
Browse files
issue/421: 适配 rmsnorm 测例修改,支持 bf16 和 f16数据类型 weights
parent
c0d1b0d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
7 deletions
+11
-7
src/infiniop/ops/causal_softmax/kunlun/kernel.h
src/infiniop/ops/causal_softmax/kunlun/kernel.h
+1
-1
src/infiniop/ops/rms_norm/kunlun/kernel.h
src/infiniop/ops/rms_norm/kunlun/kernel.h
+1
-1
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
+4
-0
src/infiniop/reduce/kunlun/reduce_kunlun.h
src/infiniop/reduce/kunlun/reduce_kunlun.h
+5
-5
No files found.
src/infiniop/ops/causal_softmax/kunlun/kernel.h
View file @
1048c1bc
...
...
@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock(
// Apply softmax
for
(
size_t
col
=
core_id
();
col
<
width
;
col
+=
BLOCK_SIZE
)
{
if
(
sum_
!=
0
)
{
y
[
col
]
=
to
<
Tdata
>
(
to
<
Tcompute
>
(
y
[
col
])
/
sum_
);
y
[
col
]
=
Tdata
(
Tcompute
(
y
[
col
])
/
sum_
);
}
else
{
y
[
col
]
=
Tdata
(
0
);
}
...
...
src/infiniop/ops/rms_norm/kunlun/kernel.h
View file @
1048c1bc
...
...
@@ -27,7 +27,7 @@ __device__ void rmsnormBlock(
for
(
size_t
i
=
core_id
();
i
<
dim
;
i
+=
BLOCK_SIZE
)
{
Tdata
xi
=
x
[
i
];
Tweight
wi
=
w
[
i
];
y
[
i
]
=
static_cast
<
Tdata
>
(
to
<
Tcompute
>
(
xi
)
*
to
<
Tcompute
>
(
wi
)
*
rms
);
y
[
i
]
=
Tdata
(
Tcompute
(
xi
)
*
Tcompute
(
wi
)
*
rms
);
}
sync_cluster
();
}
...
...
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu
View file @
1048c1bc
...
...
@@ -95,10 +95,14 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(bfloat16_t, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(bfloat16_t, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
...
...
src/infiniop/reduce/kunlun/reduce_kunlun.h
View file @
1048c1bc
...
...
@@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
for
(
size_t
i
=
core_id
();
i
<
count
;
i
+=
BLOCK_SIZE
)
{
Tdata
xi
=
data_ptr
[
i
];
ss
+=
to
<
Tcompute
>
(
xi
)
*
to
<
Tcompute
>
(
xi
);
ss
+=
Tcompute
(
xi
)
*
Tcompute
(
xi
);
}
__shared__
Tcompute
temp_storage
;
if
(
core_id
()
==
0
)
{
temp_storage
=
to
<
Tcompute
>
(
0.
f
);
temp_storage
=
Tcompute
(
0.
f
);
}
sync_cluster
();
...
...
@@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
for
(
size_t
i
=
core_id
();
i
<
count
;
i
+=
BLOCK_SIZE
)
{
Tdata
xi
=
data_ptr
[
i
];
ss
+=
to
<
Tcompute
>
(
xi
);
ss
+=
Tcompute
(
xi
);
}
__shared__
Tcompute
temp_storage
;
if
(
core_id
()
==
0
)
{
temp_storage
=
to
<
Tcompute
>
(
0.
f
);
temp_storage
=
Tcompute
(
0.
f
);
}
sync_cluster
();
...
...
@@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count)
for
(
size_t
i
=
core_id
();
i
<
count
;
i
+=
BLOCK_SIZE
)
{
Tdata
xi
=
data_ptr
[
i
];
max_val
=
fmax
(
max_val
,
to
<
Tdata
>
(
xi
));
max_val
=
fmax
(
max_val
,
Tdata
(
xi
));
}
__shared__
Tdata
temp_storage
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment