Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d49f2731
"vllm/vscode:/vscode.git/clone" did not exist on "9659bc7f271ec640da780b5ca739e261764b954b"
Unverified
Commit
d49f2731
authored
Mar 19, 2026
by
zhanqiuhu
Committed by
GitHub
Mar 19, 2026
Browse files
[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)
parent
b21d3843
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
263 additions
and
13 deletions
+263
-13
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+1
-1
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+113
-8
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+77
-1
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+30
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+42
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
d49f2731
...
@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging(
...
@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging(
connector
=
NixlConnector
(
connector
=
NixlConnector
(
vllm_config
,
vllm_config
,
KVConnectorRole
.
WORKER
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
,
hm
a_enabled
=
enable_hma
),
make_kv_cache_config
(
block_size
=
16
,
sw
a_enabled
=
enable_hma
),
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
vllm_config
,
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
d49f2731
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler
sw_sizes calculation with HMA
."""
"""Unit tests for NixlConnectorScheduler
with HMA and Mamba N-1 prefill
."""
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import (
...
@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import (
)
)
from
.utils
import
(
from
.utils
import
(
create_request
,
create_vllm_config
,
create_vllm_config
,
make_kv_cache_config
,
make_kv_cache_config
,
make_nixl_scheduler
,
)
)
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"
hm
a_enabled,expected_sw_sizes"
,
"
sw
a_enabled,expected_sw_sizes"
,
[
[
#
HM
A enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
#
SW
A enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(
True
,
[
0
,
128
+
1
]),
(
True
,
[
0
,
128
+
1
]),
#
HM
A disabled: only FullAttentionSpec (0)
#
SW
A disabled: only FullAttentionSpec (0)
(
False
,
[
0
]),
(
False
,
[
0
]),
],
],
)
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_sw_sizes
(
mock_platform
,
hm
a_enabled
,
expected_sw_sizes
):
def
test_sw_sizes
(
mock_platform
,
sw
a_enabled
,
expected_sw_sizes
):
"""Test sw_sizes is correctly computed based on
HM
A enabled/disabled."""
"""Test sw_sizes is correctly computed based on
SW
A enabled/disabled."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
NixlConnectorScheduler
,
)
)
...
@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
...
@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
# SW 2048 tokens=>128 blocks
# SW 2048 tokens=>128 blocks
kv_cache_config
=
make_kv_cache_config
(
kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
hm
a_enabled
=
hm
a_enabled
,
sw_size
=
2048
block_size
=
block_size
,
sw
a_enabled
=
sw
a_enabled
,
sw_size
=
2048
)
)
scheduler
=
NixlConnectorScheduler
(
scheduler
=
NixlConnectorScheduler
(
...
@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
...
@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
worker
.
_physical_blocks_per_logical_kv_block
=
2
# FA + SW groups (neither is MambaSpec, so both get expanded)
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker
.
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
hm
a_enabled
=
True
)
worker
.
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
sw
a_enabled
=
True
)
# Test conversion: FA + SW group
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
...
@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
...
@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
20
,
21
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
20
,
21
]
assert
len
(
req_meta
.
remote
.
block_ids
[
0
])
!=
len
(
req_meta
.
remote
.
block_ids
[
1
])
assert
len
(
req_meta
.
remote
.
block_ids
[
0
])
!=
len
(
req_meta
.
remote
.
block_ids
[
1
])
# ── Mamba N-1 prefill tests ──────────────────────────────────────────────
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"has_mamba,is_hma_required,expected_count"
,
[
(
True
,
True
,
9
),
(
False
,
False
,
10
),
(
False
,
True
,
10
),
],
ids
=
[
"mamba"
,
"fa_only"
,
"swa_only"
],
)
def
test_mamba_n1_d_side
(
has_mamba
,
is_hma_required
,
expected_count
):
"""D-side: Mamba gets N-1 matched tokens, non-Mamba gets N."""
sched
=
make_nixl_scheduler
(
has_mamba
=
has_mamba
,
is_hma_required
=
is_hma_required
)
req
=
create_request
(
num_tokens
=
10
,
do_remote_prefill
=
True
)
count
,
is_async
=
sched
.
get_num_new_matched_tokens
(
req
,
num_computed_tokens
=
0
)
assert
count
==
expected_count
assert
is_async
is
True
@
pytest
.
mark
.
cpu_test
def
test_mamba_n1_p_side_truncation
():
"""P-side: Mamba truncates prompt to N-1, sets max_tokens=1.
Also verifies idempotency (calling again is a no-op) which is
needed for preemption safety via the _p_side_truncated guard,
and that non-Mamba models skip truncation entirely.
"""
sched
=
make_nixl_scheduler
(
has_mamba
=
True
,
is_hma_required
=
True
)
req
=
create_request
(
num_tokens
=
10
,
do_remote_decode
=
True
)
req
.
max_tokens
=
128
original_len
=
len
(
req
.
prompt_token_ids
)
count
,
is_async
=
sched
.
get_num_new_matched_tokens
(
req
,
num_computed_tokens
=
0
)
assert
count
==
0
assert
is_async
is
False
assert
len
(
req
.
prompt_token_ids
)
==
original_len
-
1
assert
req
.
num_prompt_tokens
==
original_len
-
1
assert
req
.
max_tokens
==
1
assert
req
.
kv_transfer_params
[
"_p_side_truncated"
]
is
True
# Idempotency: second call must not truncate further
sched
.
get_num_new_matched_tokens
(
req
,
num_computed_tokens
=
0
)
assert
len
(
req
.
prompt_token_ids
)
==
original_len
-
1
# Non-Mamba: truncation is skipped
fa_sched
=
make_nixl_scheduler
(
has_mamba
=
False
,
is_hma_required
=
False
)
fa_req
=
create_request
(
num_tokens
=
10
,
do_remote_decode
=
True
)
fa_original
=
len
(
fa_req
.
prompt_token_ids
)
fa_sched
.
get_num_new_matched_tokens
(
fa_req
,
num_computed_tokens
=
0
)
assert
len
(
fa_req
.
prompt_token_ids
)
==
fa_original
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma"
,
[
(
True
,
True
,
True
,
True
),
(
True
,
False
,
False
,
True
),
(
False
,
False
,
False
,
False
),
],
ids
=
[
"fa_swa_mamba"
,
"fa_swa_only"
,
"fa_only"
],
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_has_mamba_init
(
mock_platform
,
swa_enabled
,
mamba_enabled
,
expected_has_mamba
,
expected_is_hma
,
):
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
)
mock_platform
.
device_type
=
"cpu"
block_size
=
16
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
# VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config
# is set; override so we can test the scheduler's own derivation.
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
False
kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
swa_enabled
=
swa_enabled
,
mamba_enabled
=
mamba_enabled
,
)
scheduler
=
NixlConnectorScheduler
(
vllm_config
=
vllm_config
,
engine_id
=
"test-engine"
,
kv_cache_config
=
kv_cache_config
,
)
assert
scheduler
.
_has_mamba
is
expected_has_mamba
assert
scheduler
.
_is_hma_required
is
expected_is_hma
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
d49f2731
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
copy
from
unittest.mock
import
patch
import
pytest
import
pytest
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
ModelRunnerOutput
,
)
from
vllm.v1.request
import
FinishReason
,
RequestStatus
from
vllm.v1.request
import
FinishReason
,
RequestStatus
from
.utils
import
(
from
.utils
import
(
...
@@ -13,6 +18,7 @@ from .utils import (
...
@@ -13,6 +18,7 @@ from .utils import (
create_request
,
create_request
,
create_scheduler
,
create_scheduler
,
create_vllm_config
,
create_vllm_config
,
make_kv_cache_config
,
)
)
pytestmark
=
pytest
.
mark
.
cpu_test
pytestmark
=
pytest
.
mark
.
cpu_test
...
@@ -579,3 +585,73 @@ def test_cannot_recv():
...
@@ -579,3 +585,73 @@ def test_cannot_recv():
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
_
=
scheduler
.
schedule
()
_
=
scheduler
.
schedule
()
assert_scheduler_empty
(
scheduler
)
assert_scheduler_empty
(
scheduler
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_p_side_chunked_prefill_mamba
(
mock_platform
):
"""P-side integration: Mamba N-1 truncation + chunked prefill completes.
A 64-token P-side request is truncated to 63 by the N-1 fix, then
chunked into two prefill steps (32 + 31) and finishes with
LENGTH_CAPPED because max_tokens is set to 1.
"""
mock_platform
.
device_type
=
"cpu"
BATCH_SIZE
=
32
NUM_TOKENS
=
64
BLOCK_SIZE
=
16
vllm_config
=
create_vllm_config
(
max_num_batched_tokens
=
BATCH_SIZE
,
block_size
=
BLOCK_SIZE
,
)
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
False
kv_cache_config
=
make_kv_cache_config
(
block_size
=
BLOCK_SIZE
,
mamba_enabled
=
True
,
num_blocks
=
10000
,
)
scheduler
=
create_scheduler
(
vllm_config
,
kv_cache_config
=
kv_cache_config
)
request
=
create_request
(
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
,
block_size
=
BLOCK_SIZE
,
)
request
.
max_tokens
=
128
scheduler
.
add_request
(
request
)
request_id
=
request
.
request_id
# ── Step 1: first chunk ──
scheduler_output
=
scheduler
.
schedule
()
assert
len
(
request
.
prompt_token_ids
)
==
NUM_TOKENS
-
1
assert
request
.
max_tokens
==
1
assert
scheduler_output
.
num_scheduled_tokens
[
request_id
]
==
BATCH_SIZE
assert
request
.
num_computed_tokens
==
BATCH_SIZE
# Model returns no tokens for intermediate prefill chunk
intermediate_output
=
ModelRunnerOutput
(
req_ids
=
[
request
.
request_id
],
req_id_to_index
=
{
request
.
request_id
:
0
},
sampled_token_ids
=
[[]],
)
scheduler
.
update_from_output
(
scheduler_output
,
intermediate_output
)
# ── Step 2: remaining chunk ──
scheduler_output
=
scheduler
.
schedule
()
remaining
=
NUM_TOKENS
-
1
-
BATCH_SIZE
# 31
assert
scheduler_output
.
num_scheduled_tokens
[
request_id
]
==
remaining
assert
request
.
num_computed_tokens
==
NUM_TOKENS
-
1
# Prefill complete: model generates 1 decode token
final_output
=
create_model_runner_output
([
request
])
engine_core_outputs
=
scheduler
.
update_from_output
(
scheduler_output
,
final_output
)
# max_tokens=1 → request finishes with LENGTH
outputs
=
engine_core_outputs
[
0
].
outputs
assert
len
(
outputs
)
==
1
assert
outputs
[
0
].
finish_reason
==
FinishReason
.
LENGTH
tests/v1/kv_connector/unit/utils.py
View file @
d49f2731
...
@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
...
@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheGroupSpec
,
MambaSpec
,
SlidingWindowSpec
,
SlidingWindowSpec
,
)
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
...
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
...
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
def
make_kv_cache_config
(
def
make_kv_cache_config
(
block_size
:
int
,
block_size
:
int
,
hma_enabled
:
bool
=
False
,
swa_enabled
:
bool
=
False
,
mamba_enabled
:
bool
=
False
,
sw_size
:
int
=
128
,
sw_size
:
int
=
128
,
num_blocks
:
int
=
100
,
num_blocks
:
int
=
100
,
)
->
KVCacheConfig
:
)
->
KVCacheConfig
:
...
@@ -438,7 +440,7 @@ def make_kv_cache_config(
...
@@ -438,7 +440,7 @@ def make_kv_cache_config(
),
),
)
)
]
]
if
hm
a_enabled
:
if
sw
a_enabled
:
kv_cache_groups
.
append
(
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
KVCacheGroupSpec
(
[
"layer1"
,
"layer3"
],
[
"layer1"
,
"layer3"
],
...
@@ -451,6 +453,32 @@ def make_kv_cache_config(
...
@@ -451,6 +453,32 @@ def make_kv_cache_config(
),
),
)
)
)
)
if
mamba_enabled
:
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
[
"mamba0"
,
"mamba1"
],
MambaSpec
(
block_size
=
block_size
,
shapes
=
((
16
,),
(
16
,)),
dtypes
=
(
torch
.
float16
,),
),
)
)
return
KVCacheConfig
(
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
)
def
make_nixl_scheduler
(
has_mamba
:
bool
=
False
,
is_hma_required
:
bool
=
False
):
"""Create a NixlConnectorScheduler via __new__ (skipping __init__).
Only sets the two flags needed by the N-1 prefill logic.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
)
sched
=
object
.
__new__
(
NixlConnectorScheduler
)
sched
.
_has_mamba
=
has_mamba
sched
.
_is_hma_required
=
is_hma_required
return
sched
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
d49f2731
...
@@ -572,6 +572,10 @@ class NixlConnectorScheduler:
...
@@ -572,6 +572,10 @@ class NixlConnectorScheduler:
for
g
in
kv_cache_config
.
kv_cache_groups
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
)
)
self
.
_has_mamba
=
any
(
isinstance
(
g
.
kv_cache_spec
,
MambaSpec
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
if
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
if
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
...
@@ -717,6 +721,39 @@ class NixlConnectorScheduler:
...
@@ -717,6 +721,39 @@ class NixlConnectorScheduler:
logger
.
warning
(
"Connection listener got unexpected message %s"
,
msg
)
logger
.
warning
(
"Connection listener got unexpected message %s"
,
msg
)
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
[
target_tp_rank
]))
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
[
target_tp_rank
]))
def
_mamba_prefill_token_count
(
self
,
num_prompt_tokens
:
int
)
->
int
:
"""D-side only. Returns N-1 for Mamba models since the decoder
always recomputes the last token and must start from h(N-1)."""
if
self
.
_has_mamba
and
num_prompt_tokens
>
1
:
return
num_prompt_tokens
-
1
return
num_prompt_tokens
def
_truncate_mamba_request_for_prefill
(
self
,
request
:
"Request"
)
->
None
:
"""P-side only: drop the last prompt token so the prefiller computes
h(N-1) instead of h(N). The decoder recomputes the last token to
derive h(N) correctly.
Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
request is preempted and rescheduled."""
params
=
request
.
kv_transfer_params
if
(
params
is
not
None
# Guard against repeated truncation after preemption/reschedule.
and
not
params
.
get
(
"_p_side_truncated"
)
and
request
.
num_prompt_tokens
>
1
):
if
request
.
prompt_token_ids
is
not
None
:
request
.
prompt_token_ids
.
pop
()
elif
request
.
prompt_embeds
is
not
None
:
request
.
prompt_embeds
=
request
.
prompt_embeds
[:
-
1
]
else
:
return
request
.
_all_token_ids
.
pop
()
request
.
num_prompt_tokens
-=
1
request
.
max_tokens
=
1
params
[
"_p_side_truncated"
]
=
True
def
get_num_new_matched_tokens
(
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
self
,
request
:
"Request"
,
num_computed_tokens
:
int
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
int
,
bool
]:
...
@@ -746,10 +783,14 @@ class NixlConnectorScheduler:
...
@@ -746,10 +783,14 @@ class NixlConnectorScheduler:
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
# Remote prefill: get all prompt blocks from remote.
# Remote prefill: get all prompt blocks from remote.
token_ids
=
request
.
prompt_token_ids
or
[]
token_ids
=
request
.
prompt_token_ids
or
[]
count
=
len
(
token_ids
)
-
num_computed_tokens
actual
=
self
.
_mamba_prefill_token_count
(
len
(
token_ids
))
count
=
actual
-
num_computed_tokens
if
count
>
0
:
if
count
>
0
:
return
count
,
True
return
count
,
True
if
params
is
not
None
and
params
.
get
(
"do_remote_decode"
)
and
self
.
_has_mamba
:
self
.
_truncate_mamba_request_for_prefill
(
request
)
# No remote prefill for this request.
# No remote prefill for this request.
return
0
,
False
return
0
,
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment