Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f29aeb5a
Unverified
Commit
f29aeb5a
authored
Oct 31, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 31, 2025
Browse files
Add FLASHINFER_MLA to test_mla_backends and add B200 CI run (#27663)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
5e8862e9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
208 additions
and
64 deletions
+208
-64
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+10
-0
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+182
-62
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+11
-1
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+5
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
f29aeb5a
...
@@ -340,6 +340,16 @@ steps:
...
@@ -340,6 +340,16 @@ steps:
commands
:
commands
:
-
pytest -v -s v1/attention
-
pytest -v -s v1/attention
-
label
:
V1 Test attention (B200)
# 10min
timeout_in_minutes
:
30
gpu
:
b200
source_file_dependencies
:
-
vllm/v1/attention
-
tests/v1/attention
commands
:
-
export VLLM_DISABLE_FLASHINFER_PREFILL=1
# TODO: FI prefill is bugged and causes incorrectness, fix this
-
pytest -v -s v1/attention
-
label
:
V1 Test others (CPU)
# 5 mins
-
label
:
V1 Test others (CPU)
# 5 mins
source_file_dependencies
:
source_file_dependencies
:
-
vllm/
-
vllm/
...
...
tests/v1/attention/test_mla_backends.py
View file @
f29aeb5a
...
@@ -14,16 +14,19 @@ import torch
...
@@ -14,16 +14,19 @@ import torch
from
tests.v1.attention.utils
import
(
from
tests.v1.attention.utils
import
(
BatchSpec
,
BatchSpec
,
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_vllm_config
,
create_vllm_config
,
try_get_attention_backend
,
try_get_attention_backend
,
)
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
,
backend_to_class_str
from
vllm.attention.ops.flashmla
import
is_flashmla_dense_supported
from
vllm.attention.ops.flashmla
import
is_flashmla_dense_supported
from
vllm.attention.utils.fa_utils
import
flash_attn_supports_mla
from
vllm.config.vllm
import
set_current_vllm_config
from
vllm.config.vllm
import
set_current_vllm_config
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.attention.backends.mla.common
import
QueryLenSupport
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
...
@@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
...
@@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
_Backend
.
CUTLASS_MLA
,
_Backend
.
CUTLASS_MLA
,
_Backend
.
FLASHMLA
,
_Backend
.
FLASHMLA
,
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
FLASHINFER_MLA
,
_Backend
.
TRITON_MLA
,
_Backend
.
TRITON_MLA
,
]
]
# Remove
CUTLASS_MLA
from the list if not using sm100
# Remove
sm100 backends
from the list if not using sm100
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
BACKENDS_TO_TEST
.
remove
(
_Backend
.
CUTLASS_MLA
)
BACKENDS_TO_TEST
.
remove
(
_Backend
.
CUTLASS_MLA
)
BACKENDS_TO_TEST
.
remove
(
_Backend
.
FLASHINFER_MLA
)
# Remove FLASH_ATTN_MLA from the list if not supported
if
not
flash_attn_supports_mla
():
BACKENDS_TO_TEST
.
remove
(
_Backend
.
FLASH_ATTN_MLA
)
# Remove FLASHMLA from the list if not supported
# Remove FLASHMLA from the list if not supported
if
not
is_flashmla_dense_supported
()[
0
]:
if
not
is_flashmla_dense_supported
()[
0
]:
BACKENDS_TO_TEST
.
remove
(
_Backend
.
FLASHMLA
)
BACKENDS_TO_TEST
.
remove
(
_Backend
.
FLASHMLA
)
SPEC_DECODE_BACKENDS
=
[]
for
backend
in
BACKENDS_TO_TEST
:
builder_cls
,
_
=
try_get_attention_backend
(
backend
)
query_len_support
=
getattr
(
builder_cls
,
"query_len_support"
,
QueryLenSupport
.
SINGLE_ONLY
)
if
query_len_support
!=
QueryLenSupport
.
SINGLE_ONLY
:
SPEC_DECODE_BACKENDS
.
append
(
backend
)
BACKEND_BLOCK_SIZES
=
{}
for
backend
in
BACKENDS_TO_TEST
:
backend_class_str
=
backend_to_class_str
(
backend
)
backend_class
=
resolve_obj_by_qualname
(
backend_class_str
)
supported_sizes
=
backend_class
.
get_supported_kernel_block_size
()
if
supported_sizes
:
default_size
=
supported_sizes
[
0
]
block_size
=
(
default_size
if
isinstance
(
default_size
,
int
)
else
default_size
.
base
)
else
:
block_size
=
16
BACKEND_BLOCK_SIZES
[
backend
]
=
block_size
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
...
@@ -236,6 +268,26 @@ class MockAttentionLayer:
...
@@ -236,6 +268,26 @@ class MockAttentionLayer:
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_q_scale_float
=
1.0
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
def
forward
(
self
,
*
_args
,
**
_kwargs
):
raise
NotImplementedError
class
MockMLAAttentionLayer
(
AttentionLayerBase
):
"""A mock MLA attention layer for populating static_forward_context."""
def
__init__
(
self
,
impl
):
self
.
impl
=
impl
def
get_attn_backend
(
self
):
raise
NotImplementedError
def
get_kv_cache_spec
(
self
,
vllm_config
):
raise
NotImplementedError
def
run_attention_backend
(
def
run_attention_backend
(
...
@@ -262,13 +314,6 @@ def run_attention_backend(
...
@@ -262,13 +314,6 @@ def run_attention_backend(
# Set the current vllm config so that get_current_vllm_config() works
# Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations
# in the backend implementations
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
# Build metadata
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
attn_metadata
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
# Instantiate MLA implementation
# Instantiate MLA implementation
num_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
num_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
vllm_config
.
parallel_config
...
@@ -302,6 +347,19 @@ def run_attention_backend(
...
@@ -302,6 +347,19 @@ def run_attention_backend(
act_dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
act_dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
impl
.
process_weights_after_loading
(
act_dtype
)
impl
.
process_weights_after_loading
(
act_dtype
)
# Populate static_forward_context with mock attention layers
for
layer_name
in
layer_names
:
vllm_config
.
compilation_config
.
static_forward_context
[
layer_name
]
=
(
MockMLAAttentionLayer
(
impl
)
)
# Build metadata
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
attn_metadata
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
# Create mock layer and output buffer
# Create mock layer and output buffer
mock_layer
=
MockAttentionLayer
(
device
)
mock_layer
=
MockAttentionLayer
(
device
)
num_tokens
=
query
.
shape
[
0
]
num_tokens
=
query
.
shape
[
0
]
...
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache.
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
"""
from
vllm.v1.attention.backends.mla.common
import
QueryLenSupport
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
is_spec_decode_test
=
batch_spec_name
.
startswith
(
"spec_decode"
)
is_spec_decode_test
=
batch_spec_name
.
startswith
(
"spec_decode"
)
spec_decode_backends
=
{
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
FLASHMLA
}
unique_block_sizes
=
sorted
(
set
(
BACKEND_BLOCK_SIZES
.
values
()))
default_block_size
=
unique_block_sizes
[
0
]
block_size
=
16
required_blocks
=
sum
(
required_blocks
=
sum
(
(
seq_len
+
block_size
-
1
)
//
block_size
for
seq_len
in
batch_spec
.
seq_lens
(
seq_len
+
default_block_size
-
1
)
//
default_block_size
for
seq_len
in
batch_spec
.
seq_lens
)
)
# Add 1 for null block at index 0, and some buffer
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks
=
required_blocks
+
1
+
100
num_gpu_blocks
=
required_blocks
+
1
+
100
...
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
model_name
=
model
,
model_name
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
),
max_model_len
=
max
(
batch_spec
.
seq_lens
),
num_gpu_blocks
=
num_gpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
block_size
=
block_size
,
block_size
=
default_
block_size
,
)
)
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
...
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
# 1. Setup
# 1. Setup
batch_size
=
batch_spec
.
batch_size
batch_size
=
batch_spec
.
batch_size
seq_lens
=
batch_spec
.
seq_lens
seq_lens
=
batch_spec
.
seq_lens
...
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
)
head_size
=
vllm_config
.
model_config
.
get_head_size
()
head_size
=
vllm_config
.
model_config
.
get_head_size
()
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
block_size
=
vllm_config
.
cache_config
.
block_size
kv_lora_rank
=
512
kv_lora_rank
=
512
qk_rope_head_dim
=
64
qk_rope_head_dim
=
64
qk_nope_head_dim
=
128
qk_nope_head_dim
=
128
...
@@ -598,12 +652,44 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -598,12 +652,44 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
)
mock_kv_b_proj
.
weight
=
torch
.
nn
.
Parameter
(
kv_b_proj_weight
.
T
,
requires_grad
=
False
)
mock_kv_b_proj
.
weight
=
torch
.
nn
.
Parameter
(
kv_b_proj_weight
.
T
,
requires_grad
=
False
)
# Create metadata using original batch spec
# 3. Create metadata and KV caches for each block size
# Group backends by block size and test each group
metadata_per_block_size
=
{}
kv_cache_per_block_size
=
{}
for
block_size
in
unique_block_sizes
:
# Create metadata for this block size
common_attn_metadata
=
create_common_attn_metadata
(
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
batch_spec
,
block_size
,
device
)
# Pad block table to meet requirement:
# block_num % (128 / block_size) == 0
required_divisor
=
int
(
128
/
block_size
)
current_block_num
=
common_attn_metadata
.
block_table_tensor
.
shape
[
1
]
if
current_block_num
%
required_divisor
!=
0
:
# Pad to next multiple of required_divisor
padded_block_num
=
(
(
current_block_num
+
required_divisor
-
1
)
//
required_divisor
)
*
required_divisor
padding_cols
=
padded_block_num
-
current_block_num
padding
=
torch
.
zeros
(
(
common_attn_metadata
.
block_table_tensor
.
shape
[
0
],
padding_cols
),
dtype
=
torch
.
int32
,
device
=
device
,
)
common_attn_metadata
.
block_table_tensor
=
torch
.
cat
(
[
common_attn_metadata
.
block_table_tensor
,
padding
],
dim
=
1
)
metadata_per_block_size
[
block_size
]
=
common_attn_metadata
# Create KV cache for this block size
required_blocks_for_size
=
sum
(
(
seq_len
+
block_size
-
1
)
//
block_size
for
seq_len
in
batch_spec
.
seq_lens
)
)
num_blocks_for_size
=
required_blocks_for_size
+
1
+
100
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache
=
create_and_prepopulate_kv_cache
(
kv_cache
=
create_and_prepopulate_kv_cache
(
kv_c_contexts
=
kv_c_contexts
,
kv_c_contexts
=
kv_c_contexts
,
k_pe_contexts
=
k_pe_contexts
,
k_pe_contexts
=
k_pe_contexts
,
...
@@ -611,20 +697,38 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -611,20 +697,38 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
head_size
=
head_size
,
head_size
=
head_size
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
num_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
,
num_blocks
=
num_blocks_for_size
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
randomize_blocks
=
True
,
randomize_blocks
=
True
,
)
)
kv_cache_per_block_size
[
block_size
]
=
kv_cache
# 4. Run vLLM backends and compare
# 4. Run vLLM backends and compare
failures
=
[]
for
backend_idx
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
):
for
backend_idx
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
):
# Skip backends that don't support spec decode for spec decode tests
# Skip backends that don't support spec decode for spec decode tests
if
is_spec_decode_test
and
backend_name
not
in
spec_decode_backends
:
if
is_spec_decode_test
and
backend_name
not
in
SPEC_DECODE_BACKENDS
:
continue
continue
# Get the appropriate block_size, metadata, and cache for this backend
block_size
=
BACKEND_BLOCK_SIZES
[
backend_name
]
common_attn_metadata
=
metadata_per_block_size
[
block_size
]
kv_cache
=
kv_cache_per_block_size
[
block_size
]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
),
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
(),
)
backend_output
=
run_attention_backend
(
backend_output
=
run_attention_backend
(
backend_name
,
backend_name
,
kv_cache_spec
,
backend_
kv_cache_spec
,
[
"placeholder"
],
[
"placeholder"
],
vllm_config
,
vllm_config
,
device
,
device
,
...
@@ -644,6 +748,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -644,6 +748,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
expected_output
=
sdpa_outputs
[
backend_name
]
expected_output
=
sdpa_outputs
[
backend_name
]
# Check shape and dtype consistency
# Check shape and dtype consistency
try
:
assert
backend_output
.
shape
==
expected_output
.
shape
,
(
assert
backend_output
.
shape
==
expected_output
.
shape
,
(
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"SDPA shape
{
expected_output
.
shape
}
"
f
"SDPA shape
{
expected_output
.
shape
}
"
...
@@ -673,3 +778,18 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -673,3 +778,18 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
f
"[
{
backend_name
}
] output differs from SDPA baseline. "
f
"[
{
backend_name
}
] output differs from SDPA baseline. "
f
"Max diff:
{
max_diff
:.
6
f
}
, max rel diff:
{
max_rel_diff
:.
6
f
}
)"
f
"Max diff:
{
max_diff
:.
6
f
}
, max rel diff:
{
max_rel_diff
:.
6
f
}
)"
)
)
except
AssertionError
as
e
:
failures
.
append
(
str
(
e
))
# Report all failures at once
if
failures
:
# Create a summary for the single-line failure message
backend_names
=
[]
for
f
in
failures
:
if
"[_Backend."
in
f
:
backend_name
=
f
.
split
(
"["
)[
1
].
split
(
"]"
)[
0
]
backend_names
.
append
(
backend_name
)
summary
=
f
"
{
len
(
failures
)
}
backend(s) failed:
{
', '
.
join
(
backend_names
)
}
"
detailed_msg
=
"
\n
"
.
join
(
failures
)
pytest
.
fail
(
f
"
{
summary
}
\n
{
detailed_msg
}
"
)
tests/v1/attention/utils.py
View file @
f29aeb5a
...
@@ -285,7 +285,17 @@ full_cg_backend_configs = {
...
@@ -285,7 +285,17 @@ full_cg_backend_configs = {
name
=
"CutlassMLA"
,
name
=
"CutlassMLA"
,
env_vars
=
{
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"CUTLASS_MLA"
,
"VLLM_ATTENTION_BACKEND"
:
"CUTLASS_MLA"
,
"FORCE_NUM_KV_SPLITS"
:
"1"
,
# TODO: remove this when hang issue is fixed
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
},
specific_gpu_arch
=
(
10
,
0
),
),
# FlashInfer MLA on Blackwell
"FlashInferMLA"
:
BackendConfig
(
name
=
"FlashInferMLA"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASHINFER_MLA"
,
},
},
comp_config
=
{
comp_config
=
{
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
...
...
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
f29aeb5a
...
@@ -6,7 +6,7 @@ from typing import ClassVar
...
@@ -6,7 +6,7 @@ from typing import ClassVar
import
torch
import
torch
from
flashinfer.decode
import
trtllm_batch_decode_with_kv_cache_mla
from
flashinfer.decode
import
trtllm_batch_decode_with_kv_cache_mla
from
vllm.attention.backends.abstract
import
AttentionLayer
,
AttentionType
from
vllm.attention.backends.abstract
import
AttentionLayer
,
AttentionType
,
MultipleOf
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonBackend
,
...
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
...
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
def
get_builder_cls
()
->
type
[
"FlashInferMLAMetadataBuilder"
]:
def
get_builder_cls
()
->
type
[
"FlashInferMLAMetadataBuilder"
]:
return
FlashInferMLAMetadataBuilder
return
FlashInferMLAMetadataBuilder
@
classmethod
def
get_supported_kernel_block_size
(
cls
)
->
list
[
int
|
MultipleOf
]:
return
[
32
,
64
]
g_fi_workspace
=
torch
.
zeros
(
g_fi_workspace
=
torch
.
zeros
(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE
,
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment