Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
58b43d4a
Commit
58b43d4a
authored
Jan 30, 2026
by
zhanghj2
Browse files
修改写出
parent
d6379e50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+3
-3
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
58b43d4a
...
@@ -131,7 +131,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
...
@@ -131,7 +131,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
float4
datas
[
ELEMS_PER_THREAD
];
float4
datas
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
lane_idx
*
4
+
i
*
256
);
// NOTE We don't use __ldg here since it is incompatible with PDL
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
lane_idx
*
8
+
i
*
4
);
// NOTE We don't use __ldg here since it is incompatible with PDL
}
}
// Warp #i accumulates activation for seq #i
// Warp #i accumulates activation for seq #i
{
{
...
@@ -155,7 +155,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
...
@@ -155,7 +155,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
result
[
i
].
z
+=
lse_scale
*
datas
[
i
].
z
;
result
[
i
].
z
+=
lse_scale
*
datas
[
i
].
z
;
result
[
i
].
w
+=
lse_scale
*
datas
[
i
].
w
;
result
[
i
].
w
+=
lse_scale
*
datas
[
i
].
w
;
if
(
split
!=
my_num_splits
-
1
)
{
if
(
split
!=
my_num_splits
-
1
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
(
split
+
1
)
*
params
.
stride_o_accum_split
+
lane_idx
*
4
+
i
*
256
);
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
(
split
+
1
)
*
params
.
stride_o_accum_split
+
lane_idx
*
8
+
i
*
4
);
}
}
}
}
// }
// }
...
@@ -182,7 +182,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
...
@@ -182,7 +182,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
data_converted
[
2
]
=
(
ElementT
)(
data
.
z
);
data_converted
[
2
]
=
(
ElementT
)(
data
.
z
);
data_converted
[
3
]
=
(
ElementT
)(
data
.
w
);
data_converted
[
3
]
=
(
ElementT
)(
data
.
w
);
static_assert
(
sizeof
(
ElementT
)
==
2
);
static_assert
(
sizeof
(
ElementT
)
==
2
);
*
(
uint64_t
*
)(
o_ptr
+
lane_idx
*
4
+
i
*
256
)
=
*
(
uint64_t
*
)
data_converted
;
*
(
uint64_t
*
)(
o_ptr
+
lane_idx
*
8
+
i
*
4
)
=
*
(
uint64_t
*
)
data_converted
;
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment