Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
261330bb
Unverified
Commit
261330bb
authored
Aug 27, 2025
by
Zeyu WANG
Committed by
GitHub
Aug 27, 2025
Browse files
fix calc space bug (#91)
* fix calc space bug * use python code to allocate the buffer for backward kernel
parent
eb758335
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
50 deletions
+10
-50
csrc/sm100/common/utils.hpp
csrc/sm100/common/utils.hpp
+1
-38
csrc/sm100/device/fmha_device_bwd.hpp
csrc/sm100/device/fmha_device_bwd.hpp
+6
-6
csrc/sm100/fmha_cutlass_bwd_sm100.cuh
csrc/sm100/fmha_cutlass_bwd_sm100.cuh
+2
-5
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+1
-1
No files found.
csrc/sm100/common/utils.hpp
View file @
261330bb
...
@@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
...
@@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
};
};
template
<
typename
T
>
template
<
typename
T
>
using
cutlass_dtype_t
=
typename
cutlass_dtype
<
T
>::
type
;
using
cutlass_dtype_t
=
typename
cutlass_dtype
<
T
>::
type
;
\ No newline at end of file
/// Owning handle for a device buffer of `T` elements backed by a torch tensor.
///
/// The buffer is allocated as an uninitialized byte tensor on the CUDA device
/// (`torch::kByte`, `torch::kCUDA`); the `tensor` member keeps a reference to
/// that allocation for as long as this object holds it. The object is
/// non-copyable. `offset` is expressed in elements of `T` and is applied by
/// `get()`, so the usable region is `[get(), get() + size())`.
template <typename T>
struct DeviceAllocation {
  T* ptr_ = nullptr;     ///< Start of the tensor storage, viewed as `T*` (offset NOT applied).
  size_t offset_ = 0;    ///< Number of leading `T` elements skipped by `get()`.
  size_t size_ = 0;      ///< Number of usable `T` elements after the offset.
  torch::Tensor tensor;  ///< Byte tensor that backs `ptr_`.

  DeviceAllocation(DeviceAllocation const&) = delete;
  DeviceAllocation& operator=(DeviceAllocation const&) = delete;

  DeviceAllocation() = default;

  /// Allocates room for `size` elements of `T` with no leading offset.
  DeviceAllocation(size_t size) { reset(size); }

  ~DeviceAllocation() = default;

  /// (Re)allocates the buffer so that it holds `size + offset` elements of `T`.
  ///
  /// @param size    Number of usable elements returned by `size()`.
  /// @param offset  Number of extra leading elements skipped by `get()`.
  ///
  /// Any previously held tensor is replaced by the new one.
  void reset(size_t size, size_t offset = 0) {
    size_ = size;
    offset_ = offset;

    // The backing tensor is measured in bytes, not in elements of T.
    size_t num_bytes = get_storage_size();
    auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
    tensor = torch::empty(num_bytes, options);

    // The tensor dtype is always kByte, so the typed accessor data_ptr<T>()
    // would be rejected for any T that is not an 8-bit unsigned type. Use the
    // untyped accessor and reinterpret the raw storage as T instead.
    ptr_ = static_cast<T*>(tensor.data_ptr());
  }

  /// Pointer to the first usable element (storage start plus `offset_`).
  T* get() { return ptr_ + offset_; }

  /// Const overload of `get()`.
  const T* get() const { return ptr_ + offset_; }

  /// Number of usable elements, excluding the leading offset.
  size_t size() const { return size_; }

  /// Total size of the backing storage in bytes, including the offset region.
  size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
};
csrc/sm100/device/fmha_device_bwd.hpp
View file @
261330bb
...
@@ -225,11 +225,11 @@ public:
...
@@ -225,11 +225,11 @@ public:
int
Q
=
cutlass
::
round_up
(
static_cast
<
int
>
(
Q_
),
8
);
// Alignment
int
Q
=
cutlass
::
round_up
(
static_cast
<
int
>
(
Q_
),
8
);
// Alignment
size_t
workspace_bytes
=
0
;
size_t
workspace_bytes
=
0
;
// OdO vector
// OdO vector
workspace_bytes
+=
B
*
H
*
Q
*
sizeof
(
ElementAccumulator
);
workspace_bytes
+=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
;
// scaled LSE vector
// scaled LSE vector
workspace_bytes
+=
B
*
H
*
Q
*
sizeof
(
ElementAccumulator
);
workspace_bytes
+=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
;
// FP32 versions of outputs that are churned (start off with Q only)
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes
+=
B
*
H
*
Q
*
D
*
sizeof
(
ElementAccumulator
);
workspace_bytes
+=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
*
D
;
return
workspace_bytes
;
return
workspace_bytes
;
}
}
...
@@ -247,7 +247,7 @@ public:
...
@@ -247,7 +247,7 @@ public:
ElementAccumulator
*
scaled_lse
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_scaled_lse
);
ElementAccumulator
*
scaled_lse
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_scaled_lse
);
ElementAccumulator
*
dQ_acc
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_dQ
);
ElementAccumulator
*
dQ_acc
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_dQ
);
params_
.
dQ_acc
=
dQ_acc
;
params_
.
dQ_acc
=
dQ_acc
;
params_
.
dQ_acc_size
=
B
*
H
*
Q
*
D
*
sizeof
(
ElementAccumulator
);
params_
.
dQ_acc_size
=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
*
D
;
auto
args_sum_OdO
=
to_sum_OdO_arguments
(
args
,
sum_OdO
,
scaled_lse
);
auto
args_sum_OdO
=
to_sum_OdO_arguments
(
args
,
sum_OdO
,
scaled_lse
);
auto
args_convert
=
to_convert_arguments
(
args
,
dQ_acc
);
auto
args_convert
=
to_convert_arguments
(
args
,
dQ_acc
);
params_
.
op_sum_OdO
.
initialize
(
args_sum_OdO
,
nullptr
,
stream
);
params_
.
op_sum_OdO
.
initialize
(
args_sum_OdO
,
nullptr
,
stream
);
...
@@ -274,9 +274,9 @@ public:
...
@@ -274,9 +274,9 @@ public:
int
Q
=
cutlass
::
round_up
(
static_cast
<
int
>
(
Q_
),
8
);
// Alignment
int
Q
=
cutlass
::
round_up
(
static_cast
<
int
>
(
Q_
),
8
);
// Alignment
char
*
workspace_chr
=
reinterpret_cast
<
char
*>
(
workspace
);
char
*
workspace_chr
=
reinterpret_cast
<
char
*>
(
workspace
);
ElementAccumulator
*
sum_OdO
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
ElementAccumulator
*
sum_OdO
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
workspace_chr
+=
B
*
H
*
Q
*
sizeof
(
ElementAccumulator
);
workspace_chr
+=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
;
ElementAccumulator
*
scaled_lse
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
ElementAccumulator
*
scaled_lse
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
workspace_chr
+=
B
*
H
*
Q
*
sizeof
(
ElementAccumulator
);
workspace_chr
+=
sizeof
(
ElementAccumulator
)
*
B
*
H
*
Q
;
ElementAccumulator
*
dQ_acc
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
ElementAccumulator
*
dQ_acc
=
reinterpret_cast
<
ElementAccumulator
*>
(
workspace_chr
);
return
initialize_split
(
args
,
dQ_acc
,
sum_OdO
,
scaled_lse
,
stream
);
return
initialize_split
(
args
,
dQ_acc
,
sum_OdO
,
scaled_lse
,
stream
);
}
}
...
...
csrc/sm100/fmha_cutlass_bwd_sm100.cuh
View file @
261330bb
...
@@ -174,13 +174,10 @@ struct BwdRunner {
...
@@ -174,13 +174,10 @@ struct BwdRunner {
Operation
op
;
Operation
op
;
size_t
workspace_size
=
0
;
uint8_t
*
workspace_ptr
=
static_cast
<
uint8_t
*>
(
workspace_buffer
.
data_ptr
());
workspace_size
=
Operation
::
get_workspace_size
(
arguments
);
DeviceAllocation
<
uint8_t
>
workspace
(
workspace_size
);
uint8_t
*
workspace_ptr
=
workspace
.
get
();
CUTLASS_CHECK
(
op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
op
.
initialize
(
arguments
,
workspace
.
get
()
));
CUTLASS_CHECK
(
op
.
initialize
(
arguments
,
workspace
_ptr
));
CUTLASS_CHECK
(
op
.
run
(
at
::
cuda
::
getCurrentCUDAStream
()));
CUTLASS_CHECK
(
op
.
run
(
at
::
cuda
::
getCurrentCUDAStream
()));
}
}
...
...
flash_mla/flash_mla_interface.py
View file @
261330bb
...
@@ -154,7 +154,7 @@ def _flash_attn_varlen_backward(
...
@@ -154,7 +154,7 @@ def _flash_attn_varlen_backward(
max_seqlen_qo_aligned
=
(
max_seqlen_qo
+
7
)
//
8
*
8
max_seqlen_qo_aligned
=
(
max_seqlen_qo
+
7
)
//
8
*
8
bs
=
cu_seqlens_qo
.
shape
[
0
]
-
1
bs
=
cu_seqlens_qo
.
shape
[
0
]
-
1
workspace_bytes
=
0
workspace_bytes
=
0
workspace_bytes
+=
4
*
qo_total_len
*
num_qo_heads
*
head_dim_qk
# dQ_acc
workspace_bytes
+=
4
*
bs
*
max_seqlen_qo_aligned
*
num_qo_heads
*
head_dim_qk
# dQ_acc
workspace_bytes
+=
4
*
max_seqlen_qo_aligned
*
bs
*
num_qo_heads
*
2
# sum_OdO and scaled_lse
workspace_bytes
+=
4
*
max_seqlen_qo_aligned
*
bs
*
num_qo_heads
*
2
# sum_OdO and scaled_lse
if
num_qo_heads
!=
num_kv_heads
:
if
num_qo_heads
!=
num_kv_heads
:
workspace_bytes
+=
2
*
kv_total_len
*
num_qo_heads
*
(
head_dim_qk
+
head_dim_vo
)
# dKV_acc
workspace_bytes
+=
2
*
kv_total_len
*
num_qo_heads
*
(
head_dim_qk
+
head_dim_vo
)
# dKV_acc
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment