Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
a4fdef4c
Commit
a4fdef4c
authored
Feb 24, 2026
by
zhanghj2
Browse files
优化softmax计算
parent
3a477917
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
12 deletions
+14
-12
csrc/softmax.h
csrc/softmax.h
+14
-12
No files found.
csrc/softmax.h
View file @
a4fdef4c
...
...
@@ -55,12 +55,13 @@ __device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst, T
// smem_reduce(row, col) = dst(0);
}
__syncthreads
();
if
(
tidx
<
16
)
{
smem_reduce
(
row
+
64
)
=
op
(
op
(
smem_reduce
(
row
*
4
),
smem_reduce
(
row
*
4
+
1
)),
op
(
smem_reduce
(
row
*
4
+
2
),
smem_reduce
(
row
*
4
+
3
)));
}
__syncthreads
();
dst
(
0
)
=
smem_reduce
(
row
+
64
);
// if (tidx < 16)
// {
// smem_reduce(row + 64) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
// }
// __syncthreads();
// dst(0) = smem_reduce(row + 64);
dst
(
0
)
=
op
(
op
(
smem_reduce
(
row
*
4
),
smem_reduce
(
row
*
4
+
1
)),
op
(
smem_reduce
(
row
*
4
+
2
),
smem_reduce
(
row
*
4
+
3
)));
}
template
<
typename
Engine0
,
typename
Layout0
,
typename
Engine1
,
typename
Layout1
,
typename
Operator
>
...
...
@@ -75,12 +76,13 @@ __device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst
smem_reduce
[
row
*
2
+
(
warp_id
/
4
)]
=
dst
[
0
];
}
__syncthreads
();
if
(
col
==
0
&&
warp_id
<
4
)
{
// printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]);
smem_reduce
[
128
+
row
]
=
op
(
smem_reduce
[
row
*
2
],
smem_reduce
[
row
*
2
+
1
]);
}
__syncthreads
();
dst
(
0
)
=
smem_reduce
(
128
+
row
);
// if (col == 0 && warp_id < 4) {
// // printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]);
// smem_reduce[128 + row] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
// }
// __syncthreads();
// dst(0) = smem_reduce(128 + row);
dst
(
0
)
=
op
(
smem_reduce
[
row
*
2
],
smem_reduce
[
row
*
2
+
1
]);
}
template
<
typename
Engine0
,
typename
Layout0
,
typename
Engine1
,
typename
Layout1
,
typename
Operator
>
__device__
__forceinline__
void
warp_allreduce_tp4
(
Tensor
<
Engine0
,
Layout0
>
&
dst
,
Tensor
<
Engine1
,
Layout1
>
&
smem_reduce
,
Operator
&
op
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment