Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c4c5912b
"openmmapi/src/PythonForce.cpp" did not exist on "b723e4af33f244f286a8586670754c7f548fc075"
Commit
c4c5912b
authored
Feb 25, 2025
by
chunyang.wen
Browse files
Update docstring
parent
bcb90f2a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+6
-6
No files found.
flash_mla/flash_mla_interface.py
View file @
c4c5912b
...
@@ -16,7 +16,7 @@ def get_mla_metadata(
...
@@ -16,7 +16,7 @@ def get_mla_metadata(
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
num_heads_k: num_heads_k.
Return:
Return
s
:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
"""
...
@@ -40,13 +40,13 @@ def flash_mla_with_kvcache(
...
@@ -40,13 +40,13 @@ def flash_mla_with_kvcache(
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head
_
dim of v.
head_dim_v: Head
dim
ension
of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return
ed
by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, return
ed
by get_mla_metadata.
softmax_scale: float. The scal
ing
of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
softmax_scale: float. The scal
e
of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
causal: bool. Whether to apply causal attention mask.
Return:
Return
s
:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment