Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
68a72a5c
Unverified
Commit
68a72a5c
authored
Nov 07, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Nov 07, 2025
Browse files
Revert "[PerfFix] Avoid separate thread for MP executor shm spin (#28012)" (#28289)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
0f872b79
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
131 additions
and
143 deletions
+131
-143
tests/v1/executor/test_executor.py
tests/v1/executor/test_executor.py
+1
-2
tests/v1/kv_connector/unit/test_output_aggregator.py
tests/v1/kv_connector/unit/test_output_aggregator.py
+20
-12
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+29
-14
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+2
-2
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+61
-67
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+6
-5
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+4
-4
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+7
-36
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-1
No files found.
tests/v1/executor/test_executor.py
View file @
68a72a5c
...
...
@@ -4,7 +4,6 @@
import
asyncio
import
os
from
collections.abc
import
Callable
from
concurrent.futures
import
Future
from
typing
import
Any
import
pytest
...
...
@@ -28,7 +27,7 @@ class CustomMultiprocExecutor(MultiprocExecutor):
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
unique_reply_rank
:
int
|
None
=
None
,
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]
:
)
->
list
[
Any
]:
# Drop marker to show that this was run
with
open
(
".marker"
,
"w"
):
...
...
...
tests/v1/kv_connector/unit/test_output_aggregator.py
View file @
68a72a5c
...
...
@@ -89,12 +89,14 @@ def test_aggregate_workers_output():
def
test_async_aggregate_workers_output
():
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
2
)
future
:
Future
[
list
[
DummyModelRunnerOutput
]]
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
future1
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
future2
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
result_future
=
aggregator
.
async_aggregate
([
future1
,
future2
])
output1
=
DummyModelRunnerOutput
()
output2
=
DummyModelRunnerOutput
()
future
.
set_result
([
output1
,
output2
])
future1
.
set_result
(
output1
)
future2
.
set_result
(
output2
)
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
...
...
@@ -104,14 +106,16 @@ def test_async_aggregate_workers_output():
assert
aggregated
.
finished_recving
is
None
assert
not
aggregated
.
invalid_block_ids
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
future1
=
Future
()
future2
=
Future
()
result_future
=
aggregator
.
async_aggregate
([
future1
,
future2
])
output1
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
},
finished_recving
=
{
"req2"
}
)
output2
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
1
})
future
.
set_result
([
output1
,
output2
])
future1
.
set_result
(
output1
)
future2
.
set_result
(
output2
)
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
...
...
@@ -121,12 +125,14 @@ def test_async_aggregate_workers_output():
assert
aggregated
.
finished_recving
is
None
assert
aggregated
.
invalid_block_ids
==
{
1
}
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
future1
=
Future
()
future2
=
Future
()
result_future
=
aggregator
.
async_aggregate
([
future1
,
future2
])
output1
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
2
})
output2
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
})
future
.
set_result
([
output1
,
output2
])
future1
.
set_result
(
output1
)
future2
.
set_result
(
output2
)
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
...
...
@@ -136,14 +142,16 @@ def test_async_aggregate_workers_output():
assert
aggregated
.
finished_recving
is
None
assert
aggregated
.
invalid_block_ids
==
{
2
}
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
future1
=
Future
()
future2
=
Future
()
result_future
=
aggregator
.
async_aggregate
([
future1
,
future2
])
output1
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
3
,
4
})
output2
=
DummyModelRunnerOutput
(
finished_recving
=
{
"req2"
},
invalid_block_ids
=
{
4
,
5
}
)
future
.
set_result
([
output1
,
output2
])
future1
.
set_result
(
output1
)
future2
.
set_result
(
output2
)
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
68a72a5c
...
...
@@ -221,24 +221,39 @@ class KVOutputAggregator:
def
async_aggregate
(
self
,
output_future
:
Future
[
Sequence
[
ModelRunnerOutput
|
None
]],
output_future
s
:
Sequenc
e
[
Futur
e
[
ModelRunnerOutput
|
None
]],
output_rank
:
int
=
0
,
)
->
Future
[
ModelRunnerOutput
|
None
]:
"""Takes a
future that resolves to a
list of
o
ut
put
s and returns a
future
which resolves to a single aggregated
output."""
"""Takes a list of
f
ut
ure
s and returns a
single future which resolves
to the respective list of
output
s
."""
result_future
:
Future
[
ModelRunnerOutput
|
None
]
=
Future
()
def
callback
(
fut
):
if
result_future
.
done
():
return
try
:
result_future
.
set_result
(
self
.
aggregate
(
fut
.
result
(),
output_rank
))
except
CancelledError
:
result_future
.
cancel
()
except
Exception
as
e
:
result_future
.
set_exception
(
e
)
output_future
.
add_done_callback
(
callback
)
outputs
:
list
[
ModelRunnerOutput
|
None
]
=
[
None
]
*
len
(
output_futures
)
remaining
=
len
(
output_futures
)
def
make_callback
(
idx
):
def
callback
(
fut
):
if
result_future
.
done
():
return
try
:
outputs
[
idx
]
=
fut
.
result
()
except
CancelledError
:
result_future
.
cancel
()
except
Exception
as
e
:
result_future
.
set_exception
(
e
)
# this check assumes io_thread_pool uses a single thread
nonlocal
remaining
remaining
-=
1
if
not
remaining
:
result_future
.
set_result
(
self
.
aggregate
(
outputs
,
output_rank
))
return
callback
for
i
,
output_future
in
enumerate
(
output_futures
):
output_future
.
add_done_callback
(
make_callback
(
i
))
return
result_future
...
...
vllm/v1/executor/abstract.py
View file @
68a72a5c
...
...
@@ -171,7 +171,7 @@ class Executor(ABC):
args
:
tuple
=
(),
kwargs
:
dict
|
None
=
None
,
non_block
:
Literal
[
True
]
=
True
,
)
->
Future
[
list
[
_R
]]:
)
->
list
[
Future
[
_R
]]:
pass
@
abstractmethod
...
...
@@ -219,7 +219,7 @@ class Executor(ABC):
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
output
=
self
.
collective_rpc
(
# type: ignore[call-overload]
"sample_tokens"
,
args
=
(
grammar_output
,),
non_block
=
non_block
)
...
...
vllm/v1/executor/multiproc_executor.py
View file @
68a72a5c
...
...
@@ -9,10 +9,8 @@ import threading
import
time
import
traceback
import
weakref
from
collections
import
deque
from
collections.abc
import
Callable
from
concurrent.futures
import
Future
,
InvalidStateError
from
contextlib
import
suppress
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
functools
import
cached_property
,
partial
...
...
@@ -56,30 +54,6 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase
logger
=
init_logger
(
__name__
)
class
FutureWrapper
(
Future
):
def
__init__
(
self
,
futures_queue
:
deque
[
tuple
[
"FutureWrapper"
,
Callable
]]):
self
.
futures_queue
=
futures_queue
super
().
__init__
()
def
result
(
self
,
timeout
=
None
):
if
timeout
is
not
None
:
raise
RuntimeError
(
"timeout not implemented"
)
# Drain any futures ahead of us in the queue.
while
not
self
.
done
():
future
,
get_response
=
self
.
futures_queue
.
pop
()
future
.
wait_for_response
(
get_response
)
return
super
().
result
()
def
wait_for_response
(
self
,
get_response
:
Callable
):
try
:
response
=
get_response
()
with
suppress
(
InvalidStateError
):
self
.
set_result
(
response
)
except
Exception
as
e
:
with
suppress
(
InvalidStateError
):
self
.
set_exception
(
e
)
class
MultiprocExecutor
(
Executor
):
supports_pp
:
bool
=
True
...
...
@@ -90,6 +64,7 @@ class MultiprocExecutor(Executor):
self
.
is_failed
=
False
self
.
shutdown_event
=
threading
.
Event
()
self
.
failure_callback
:
FailureCallback
|
None
=
None
self
.
io_thread_pool
:
ThreadPoolExecutor
|
None
=
None
self
.
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
...
...
@@ -157,7 +132,12 @@ class MultiprocExecutor(Executor):
uw
.
death_writer
.
close
()
self
.
_ensure_worker_termination
([
uw
.
proc
for
uw
in
unready_workers
])
self
.
futures_queue
=
deque
[
tuple
[
FutureWrapper
,
Callable
]]()
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self
.
io_thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
...
...
@@ -215,13 +195,14 @@ class MultiprocExecutor(Executor):
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
if
not
self
.
has_connector
:
# get output only from a single worker (output_rank)
return
self
.
collective_rpc
(
(
output
,)
=
self
.
collective_rpc
(
method
,
args
=
args
,
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
)
return
output
# get output from all workers
outputs
=
self
.
collective_rpc
(
...
...
@@ -242,11 +223,12 @@ class MultiprocExecutor(Executor):
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
# OPTIMIZATION: Get output only from a single worker (output_rank)
return
self
.
collective_rpc
(
outputs
=
self
.
collective_rpc
(
"take_draft_token_ids"
,
unique_reply_rank
=
self
.
output_rank
)
return
outputs
[
0
]
def
collective_rpc
(
# type: ignore[override]
def
collective_rpc
(
self
,
method
:
str
|
Callable
,
timeout
:
float
|
None
=
None
,
...
...
@@ -254,9 +236,7 @@ class MultiprocExecutor(Executor):
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
unique_reply_rank
:
int
|
None
=
None
,
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
"""Returns single result if unique_reply_rank is provided, otherwise list."""
)
->
list
[
Any
]:
if
self
.
is_failed
:
raise
RuntimeError
(
"Executor failed."
)
...
...
@@ -266,52 +246,63 @@ class MultiprocExecutor(Executor):
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try
:
if
isinstance
(
method
,
str
):
send_method
=
method
else
:
send_method
=
cloudpickle
.
dumps
(
method
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
self
.
rpc_broadcast_mq
.
enqueue
(
(
send_method
,
args
,
kwargs
,
unique_reply_rank
)
)
if
isinstance
(
method
,
str
):
send_method
=
method
else
:
send_method
=
cloudpickle
.
dumps
(
method
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
self
.
rpc_broadcast_mq
.
enqueue
((
send_method
,
args
,
kwargs
,
unique_reply_rank
))
workers
=
(
(
self
.
workers
[
unique_reply_rank
],)
if
unique_reply_rank
is
not
None
else
self
.
workers
)
responses
=
[]
workers
=
(
(
self
.
workers
[
unique_reply_rank
],)
if
unique_reply_rank
is
not
None
else
self
.
workers
)
def
get_response
(
w
:
WorkerProcHandle
,
dequeue_timeout
:
float
|
None
=
None
,
cancel_event
:
threading
.
Event
|
None
=
None
,
):
status
,
result
=
w
.
worker_response_mq
.
dequeue
(
timeout
=
dequeue_timeout
,
cancel
=
cancel_event
)
shutdown_event
=
self
.
shutdown_event
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
raise
RuntimeError
(
f
"Worker failed with error '
{
result
}
', please check the"
" stack trace above for the root cause"
)
return
result
def
get_response
():
responses
=
[]
for
w
in
workers
:
dequeue_timeout
=
(
None
if
deadline
is
None
else
(
deadline
-
time
.
monotonic
())
)
try
:
status
,
result
=
w
.
worker_response_mq
.
dequeue
(
timeout
=
dequeue_timeout
,
cancel
=
shutdown_event
if
self
.
io_thread_pool
is
not
None
:
# We must consume worker_response_mq from a single thread.
result
=
self
.
io_thread_pool
.
submit
(
# type: ignore
get_response
,
w
,
dequeue_timeout
,
self
.
shutdown_event
)
except
TimeoutError
as
e
:
raise
TimeoutError
(
f
"RPC call to
{
method
}
timed out."
)
from
e
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
if
not
non_block
:
result
=
result
.
result
()
elif
not
non_block
:
result
=
get_response
(
w
,
dequeue_timeout
,
self
.
shutdown_event
)
else
:
raise
RuntimeError
(
f
"Worker failed with error '
{
result
}
', please check the"
" stack trace above for the root cause"
"non_block can only be used when max_concurrent_batches > 1"
)
responses
.
append
(
result
)
return
responses
[
0
]
if
unique_reply_rank
is
not
None
else
responses
if
non_block
:
future
=
FutureWrapper
(
self
.
futures_queue
)
self
.
futures_queue
.
appendleft
((
future
,
get_response
))
return
future
# First drain any pending futures in the queue.
while
self
.
futures_queue
:
future
,
get_fut_response
=
self
.
futures_queue
.
pop
()
future
.
wait_for_response
(
get_fut_response
)
return
get_response
()
return
responses
except
TimeoutError
as
e
:
raise
TimeoutError
(
f
"RPC call to
{
method
}
timed out."
)
from
e
@
staticmethod
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
...
...
@@ -357,6 +348,9 @@ class MultiprocExecutor(Executor):
self
.
_ensure_worker_termination
([
w
.
proc
for
w
in
workers
])
self
.
shutdown_event
.
set
()
if
self
.
io_thread_pool
is
not
None
:
self
.
io_thread_pool
.
shutdown
(
wait
=
False
,
cancel_futures
=
True
)
del
self
.
io_thread_pool
self
.
rpc_broadcast_mq
=
None
...
...
vllm/v1/executor/ray_executor.py
View file @
68a72a5c
...
...
@@ -435,25 +435,26 @@ class RayDistributedExecutor(Executor):
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return
FutureWrapper
(
refs
[
0
]
)
return
FutureWrapper
(
refs
)
# Get output from all workers when connector is present
assert
self
.
kv_output_aggregator
is
not
None
if
not
non_block
:
# Block and get results from all workers
return
self
.
kv_output_aggregator
.
aggregate
(
ray
.
get
(
refs
))
outputs
=
[
ref
.
get
()
for
ref
in
refs
]
return
self
.
kv_output_aggregator
.
aggregate
(
outputs
)
# Return a future that will aggregate outputs from all workers
return
FutureWrapper
(
refs
,
self
.
kv_output_aggregator
)
def
collective_rpc
(
# type: ignore[override]
def
collective_rpc
(
self
,
method
:
str
|
Callable
,
timeout
:
float
|
None
=
None
,
args
:
tuple
=
(),
kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
non_block
:
bool
=
False
,
)
->
list
[
Any
]
|
Future
[
list
[
Any
]]
:
)
->
list
[
Any
]:
"""Runs the given method on all workers."""
sent_method
=
method
if
isinstance
(
method
,
str
)
else
cloudpickle
.
dumps
(
method
)
del
method
...
...
@@ -469,7 +470,7 @@ class RayDistributedExecutor(Executor):
# Get the results of the ray workers.
if
non_block
:
return
FutureWrapper
(
ray_worker_outputs
)
return
[
FutureWrapper
(
(
output
,))
for
output
in
ray_worker_outputs
]
return
ray
.
get
(
ray_worker_outputs
,
timeout
=
timeout
)
...
...
vllm/v1/executor/ray_utils.py
View file @
68a72a5c
...
...
@@ -141,19 +141,19 @@ class FutureWrapper(Future):
the result() call. If not only the first worker's output is returned.
"""
def
__init__
(
self
,
ref_or_
refs
,
aggregator
:
KVOutputAggregator
|
None
=
None
):
def
__init__
(
self
,
refs
,
aggregator
:
KVOutputAggregator
|
None
=
None
):
super
().
__init__
()
self
.
ref
_or_refs
=
ref_or_
refs
self
.
ref
s
=
refs
self
.
aggregator
=
aggregator
def
result
(
self
,
timeout
=
None
):
if
timeout
is
not
None
:
raise
NotImplementedError
(
"timeout is not supported"
)
outputs
=
ray
.
get
(
self
.
ref_or_refs
,
timeout
=
timeout
)
if
self
.
aggregator
is
None
:
return
outputs
return
self
.
refs
[
0
].
get
()
outputs
=
[
ref
.
get
()
for
ref
in
self
.
refs
]
return
self
.
aggregator
.
aggregate
(
outputs
,
output_rank
=
0
)
...
...
vllm/v1/executor/uniproc_executor.py
View file @
68a72a5c
...
...
@@ -13,10 +13,9 @@ import torch.distributed as dist
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils.network_utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
,
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.v1.serial_utils
import
run_method
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -59,60 +58,32 @@ class UniProcExecutor(Executor):
def
max_concurrent_batches
(
self
)
->
int
:
return
2
if
self
.
scheduler_config
.
async_scheduling
else
1
def
collective_rpc
(
# type: ignore[override]
def
collective_rpc
(
self
,
method
:
str
|
Callable
,
timeout
:
float
|
None
=
None
,
args
:
tuple
=
(),
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
single_value
:
bool
=
False
,
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
)
->
list
[
Any
]:
if
kwargs
is
None
:
kwargs
=
{}
if
not
non_block
:
result
=
run_method
(
self
.
driver_worker
,
method
,
args
,
kwargs
)
return
result
if
single_value
else
[
result
]
return
[
run_method
(
self
.
driver_worker
,
method
,
args
,
kwargs
)]
try
:
result
=
run_method
(
self
.
driver_worker
,
method
,
args
,
kwargs
)
if
isinstance
(
result
,
AsyncModelRunnerOutput
):
if
(
async_thread
:
=
self
.
async_output_thread
)
is
not
None
:
get_output
=
result
.
get_output
if
not
single_value
:
get_output
=
lambda
:
[
get_output
()]
return
async_thread
.
submit
(
get_output
)
return
[
async_thread
.
submit
(
result
.
get_output
)]
result
=
result
.
get_output
()
future
=
Future
[
Any
]()
future
.
set_result
(
result
if
single_value
else
[
result
]
)
future
.
set_result
(
result
)
except
Exception
as
e
:
future
=
Future
[
Any
]()
future
.
set_exception
(
e
)
return
future
def
execute_model
(
# type: ignore[override]
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
return
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,),
non_block
=
non_block
,
single_value
=
True
,
)
def
sample_tokens
(
# type: ignore[override]
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
return
self
.
collective_rpc
(
"sample_tokens"
,
args
=
(
grammar_output
,),
non_block
=
non_block
,
single_value
=
True
,
)
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
return
self
.
collective_rpc
(
"take_draft_token_ids"
,
single_value
=
True
)
return
[
future
]
def
check_health
(
self
)
->
None
:
# UniProcExecutor will always be healthy as long as
...
...
vllm/v1/worker/gpu_worker.py
View file @
68a72a5c
...
...
@@ -524,7 +524,7 @@ class Worker(WorkerBase):
@
torch
.
inference_mode
()
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput
| None
"
self
,
grammar_output
:
"GrammarOutput"
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
:
return
self
.
model_runner
.
sample_tokens
(
grammar_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment