Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eeb135eb
Unverified
Commit
eeb135eb
authored
Sep 16, 2025
by
Nick Hill
Committed by
GitHub
Sep 16, 2025
Browse files
[Core] Use `CpuGpuBuffer` for block table tensors (#24795)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
3059b9cc
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
53 additions
and
63 deletions
+53
-63
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+1
-1
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+4
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+1
-1
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+32
-43
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+4
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-13
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
eeb135eb
...
@@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
...
@@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
return
False
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
block_table_values
=
block_table
.
block_table
_
np
[
req_index
,
:
num_blocks
]
block_table_values
=
block_table
.
block_table
.
np
[
req_index
,
:
num_blocks
]
return
(
block_table_values
==
req_block_ids
).
all
()
return
(
block_table_values
==
req_block_ids
).
all
()
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
eeb135eb
...
@@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
...
@@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.logits_processor
import
LogitsProcessors
from
vllm.v1.sample.logits_processor
import
LogitsProcessors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.block_table
import
BlockTable
,
MultiGroupBlockTable
from
vllm.v1.worker.block_table
import
BlockTable
,
MultiGroupBlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
@@ -45,7 +46,7 @@ def _compare_objs(obj1,
...
@@ -45,7 +46,7 @@ def _compare_objs(obj1,
is_same
=
False
is_same
=
False
if
isinstance
(
a
,
torch
.
Tensor
):
if
isinstance
(
a
,
torch
.
Tensor
):
if
(
a
.
numel
()
==
0
or
b
.
numel
()
==
0
)
:
if
a
.
numel
()
==
0
or
b
.
numel
()
==
0
:
is_same
=
(
a
.
numel
()
==
0
and
b
.
numel
()
==
0
)
is_same
=
(
a
.
numel
()
==
0
and
b
.
numel
()
==
0
)
elif
torch
.
allclose
(
a
,
b
):
elif
torch
.
allclose
(
a
,
b
):
is_same
=
True
is_same
=
True
...
@@ -61,6 +62,8 @@ def _compare_objs(obj1,
...
@@ -61,6 +62,8 @@ def _compare_objs(obj1,
is_same
=
True
# if we make it here must be same
is_same
=
True
# if we make it here must be same
elif
a
==
b
:
elif
a
==
b
:
is_same
=
True
is_same
=
True
elif
isinstance
(
a
,
CpuGpuBuffer
):
is_same
=
np
.
allclose
(
a
.
np
,
b
.
np
)
and
torch
.
allclose
(
a
.
gpu
,
b
.
gpu
)
assert
is_same
,
f
"Attribute
{
attr_name
}
is different"
\
assert
is_same
,
f
"Attribute
{
attr_name
}
is different"
\
f
" in
{
obj1
}
and
{
obj2
}
:
{
a
}
!=
{
b
}
"
f
" in
{
obj1
}
and
{
obj2
}
:
{
a
}
!=
{
b
}
"
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
eeb135eb
...
@@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
...
@@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state
.
block_ids
[
0
]):
req_state
.
block_ids
[
0
]):
return
False
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table
_
np
[
req_index
,
:
num_blocks
]
==
return
(
block_table
.
block_table
.
np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
[
0
]).
all
()
req_state
.
block_ids
[
0
]).
all
()
...
...
vllm/v1/worker/block_table.py
View file @
eeb135eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -7,6 +8,7 @@ import torch
...
@@ -7,6 +8,7 @@ import torch
from
vllm.distributed
import
get_dcp_group
from
vllm.distributed
import
get_dcp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.utils
import
CpuGpuBuffer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -29,28 +31,13 @@ class BlockTable:
...
@@ -29,28 +31,13 @@ class BlockTable:
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
device
=
device
self
.
block_table
=
torch
.
zeros
(
self
.
block_table
=
self
.
_make_buffer
(
max_num_reqs
,
(
max_num_reqs
,
max_num_blocks_per_req
),
max_num_blocks_per_req
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
,
)
self
.
block_table_cpu
=
torch
.
zeros
(
(
max_num_reqs
,
max_num_blocks_per_req
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
block_table_np
=
self
.
block_table_cpu
.
numpy
()
self
.
num_blocks_per_row
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_blocks_per_row
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
slot_mapping_cpu
=
torch
.
zeros
(
self
.
max_num_batched_tokens
,
self
.
slot_mapping
=
self
.
_make_buffer
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
)
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
slot_mapping
=
torch
.
zeros
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
try
:
try
:
self
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
...
@@ -69,7 +56,7 @@ class BlockTable:
...
@@ -69,7 +56,7 @@ class BlockTable:
num_blocks
=
len
(
block_ids
)
num_blocks
=
len
(
block_ids
)
start
=
self
.
num_blocks_per_row
[
row_idx
]
start
=
self
.
num_blocks_per_row
[
row_idx
]
self
.
num_blocks_per_row
[
row_idx
]
+=
num_blocks
self
.
num_blocks_per_row
[
row_idx
]
+=
num_blocks
self
.
block_table
_
np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
self
.
block_table
.
np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
def
add_row
(
self
,
block_ids
:
list
[
int
],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
list
[
int
],
row_idx
:
int
)
->
None
:
self
.
num_blocks_per_row
[
row_idx
]
=
0
self
.
num_blocks_per_row
[
row_idx
]
=
0
...
@@ -77,17 +64,14 @@ class BlockTable:
...
@@ -77,17 +64,14 @@ class BlockTable:
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
num_blocks
=
self
.
num_blocks_per_row
[
src
]
num_blocks
=
self
.
num_blocks_per_row
[
src
]
self
.
block_table_np
[
tgt
,
:
num_blocks
]
=
self
.
block_table
_
np
[
block_table_np
=
self
.
block_table
.
np
src
,
:
num_blocks
]
block_table_np
[
tgt
,
:
num_blocks
]
=
block_table_np
[
src
,
:
num_blocks
]
self
.
num_blocks_per_row
[
tgt
]
=
num_blocks
self
.
num_blocks_per_row
[
tgt
]
=
num_blocks
def
swap_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
def
swap_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
num_blocks_src
=
self
.
num_blocks_per_row
[
src
]
src_tgt
,
tgt_src
=
[
src
,
tgt
],
[
tgt
,
src
]
num_blocks_tgt
=
self
.
num_blocks_per_row
[
tgt
]
self
.
num_blocks_per_row
[
src_tgt
]
=
self
.
num_blocks_per_row
[
tgt_src
]
self
.
num_blocks_per_row
[
src
]
=
num_blocks_tgt
self
.
block_table
.
np
[
src_tgt
]
=
self
.
block_table
.
np
[
tgt_src
]
self
.
num_blocks_per_row
[
tgt
]
=
num_blocks_src
self
.
block_table_np
[[
src
,
tgt
]]
=
self
.
block_table_np
[[
tgt
,
src
]]
def
compute_slot_mapping
(
self
,
req_indices
:
np
.
ndarray
,
def
compute_slot_mapping
(
self
,
req_indices
:
np
.
ndarray
,
positions
:
np
.
ndarray
)
->
None
:
positions
:
np
.
ndarray
)
->
None
:
...
@@ -107,7 +91,7 @@ class BlockTable:
...
@@ -107,7 +91,7 @@ class BlockTable:
virtual_block_size
=
self
.
block_size
*
self
.
dcp_world_size
virtual_block_size
=
self
.
block_size
*
self
.
dcp_world_size
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
positions
//
virtual_block_size
)
positions
//
virtual_block_size
)
block_numbers
=
self
.
block_table
_
np
.
ravel
()[
block_table_indices
]
block_numbers
=
self
.
block_table
.
np
.
ravel
()[
block_table_indices
]
# Use virtual_block_size for mask calculation, which marks local
# Use virtual_block_size for mask calculation, which marks local
# tokens.
# tokens.
virtual_block_offsets
=
positions
%
virtual_block_size
virtual_block_offsets
=
positions
%
virtual_block_size
...
@@ -117,40 +101,45 @@ class BlockTable:
...
@@ -117,40 +101,45 @@ class BlockTable:
# Calculate slot_mapping
# Calculate slot_mapping
slot_mapping
=
block_numbers
*
self
.
block_size
+
block_offsets
slot_mapping
=
block_numbers
*
self
.
block_size
+
block_offsets
# Write final slots, use -1 for not-local
# Write final slots, use -1 for not-local
self
.
slot_mapping
_
np
[:
req_indices
.
shape
[
0
]]
=
np
.
where
(
self
.
slot_mapping
.
np
[:
req_indices
.
shape
[
0
]]
=
np
.
where
(
mask
,
slot_mapping
,
-
1
)
mask
,
slot_mapping
,
-
1
)
else
:
else
:
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
positions
//
self
.
block_size
)
positions
//
self
.
block_size
)
block_numbers
=
self
.
block_table
_
np
.
ravel
()[
block_table_indices
]
block_numbers
=
self
.
block_table
.
np
.
ravel
()[
block_table_indices
]
block_offsets
=
positions
%
self
.
block_size
block_offsets
=
positions
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
block_offsets
,
out
=
self
.
slot_mapping
_
np
[:
req_indices
.
shape
[
0
]])
out
=
self
.
slot_mapping
.
np
[:
req_indices
.
shape
[
0
]])
def
commit_block_table
(
self
,
num_reqs
:
int
)
->
None
:
def
commit_block_table
(
self
,
num_reqs
:
int
)
->
None
:
self
.
block_table
[:
num_reqs
].
copy_
(
self
.
block_table_cpu
[:
num_reqs
],
self
.
block_table
.
copy_to_gpu
(
num_reqs
)
non_blocking
=
True
)
def
commit_slot_mapping
(
self
,
num_tokens
:
int
)
->
None
:
def
commit_slot_mapping
(
self
,
num_tokens
:
int
)
->
None
:
self
.
slot_mapping
[:
num_tokens
].
copy_
(
self
.
slot_mapping
.
copy_to_gpu
(
num_tokens
)
self
.
slot_mapping_cpu
[:
num_tokens
],
non_blocking
=
True
)
def
clear
(
self
)
->
None
:
def
clear
(
self
)
->
None
:
self
.
block_table
.
fill_
(
0
)
self
.
block_table
.
gpu
.
fill_
(
0
)
self
.
block_table
_
cpu
.
fill_
(
0
)
self
.
block_table
.
cpu
.
fill_
(
0
)
def
get_device_tensor
(
self
)
->
torch
.
Tensor
:
def
get_device_tensor
(
self
,
num_reqs
:
int
)
->
torch
.
Tensor
:
"""Returns the device tensor of the block table."""
"""Returns the device tensor of the block table."""
return
self
.
block_table
return
self
.
block_table
.
gpu
[:
num_reqs
]
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
"""Returns the CPU tensor of the block table."""
"""Returns the CPU tensor of the block table."""
return
self
.
block_table
_
cpu
return
self
.
block_table
.
cpu
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
"""Returns the numpy array of the block table."""
"""Returns the numpy array of the block table."""
return
self
.
block_table_np
return
self
.
block_table
.
np
def
_make_buffer
(
self
,
*
size
:
Union
[
int
,
torch
.
SymInt
],
dtype
:
torch
.
dtype
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
*
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
class
MultiGroupBlockTable
:
class
MultiGroupBlockTable
:
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
eeb135eb
...
@@ -89,7 +89,7 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -89,7 +89,7 @@ class CPUModelRunner(GPUModelRunner):
assert
isinstance
(
device_tensor
,
torch
.
Tensor
)
assert
isinstance
(
device_tensor
,
torch
.
Tensor
)
setattr
(
obj
,
device_attr_name
,
cpu_tensor
)
setattr
(
obj
,
device_attr_name
,
cpu_tensor
)
for
k
,
v
in
vars
(
self
).
item
s
():
for
v
in
vars
(
self
).
value
s
():
if
isinstance
(
v
,
CpuGpuBuffer
):
if
isinstance
(
v
,
CpuGpuBuffer
):
v
.
gpu
=
v
.
cpu
v
.
gpu
=
v
.
cpu
...
@@ -98,9 +98,9 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -98,9 +98,9 @@ class CPUModelRunner(GPUModelRunner):
replace_tensor
(
self
.
input_batch
,
k
,
k
[:
-
11
])
replace_tensor
(
self
.
input_batch
,
k
,
k
[:
-
11
])
for
block_table
in
self
.
input_batch
.
block_table
.
block_tables
:
for
block_table
in
self
.
input_batch
.
block_table
.
block_tables
:
for
k
,
v
in
vars
(
block_table
).
item
s
():
for
v
in
vars
(
block_table
).
value
s
():
if
k
.
endswith
(
"_cpu"
)
and
isinstance
(
v
,
torch
.
Tenso
r
):
if
isinstance
(
v
,
CpuGpuBuffe
r
):
replace_tensor
(
block_table
,
k
,
k
[:
-
4
])
v
.
gpu
=
v
.
cpu
def
load_model
(
self
,
eep_scale_up
:
bool
=
False
)
->
None
:
def
load_model
(
self
,
eep_scale_up
:
bool
=
False
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
eeb135eb
...
@@ -427,9 +427,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -427,9 +427,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
*
size
:
Union
[
int
,
torch
.
SymInt
],
*
size
:
Union
[
int
,
torch
.
SymInt
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
numpy
:
bool
=
True
)
->
CpuGpuBuffer
:
numpy
:
bool
=
True
)
->
CpuGpuBuffer
:
# Bfloat16 torch tensors cannot be directly cast to a numpy array, so
# if a bfloat16 buffer is needed without a corresponding numpy array,
# don't bother instantiating the numpy array.
return
CpuGpuBuffer
(
*
size
,
return
CpuGpuBuffer
(
*
size
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -1062,13 +1059,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1062,13 +1059,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_common_prefix_blocks
=
0
num_common_prefix_blocks
=
0
else
:
else
:
blk_table
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
blk_table
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
blk_table_tensor
=
blk_table
.
get_device_tensor
(
)[:
num_reqs
]
blk_table_tensor
=
blk_table
.
get_device_tensor
(
num_reqs
)
slot_mapping
=
blk_table
.
slot_mapping
[:
slot_mapping
=
blk_table
.
slot_mapping
.
gpu
[:
total_num_scheduled_tokens
]
total_num_scheduled_tokens
]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
# graph mode.
blk_table
.
slot_mapping
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
blk_table
.
slot_mapping
.
gpu
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
num_common_prefix_blocks
=
(
num_common_prefix_blocks
=
(
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
])
num_common_prefix_blocks
[
kv_cache_group_id
])
...
@@ -2903,10 +2901,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2903,10 +2901,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
self
.
max_model_len
,
max_seq_len
=
self
.
max_model_len
,
block_table_tensor
=
self
.
input_batch
.
block_table
[
block_table_tensor
=
self
.
input_batch
.
kv_cache_group_id
].
get_device_tensor
(
)[:
num_reqs
]
,
block_table
[
kv_cache_group_id
].
get_device_tensor
(
num_reqs
)
,
slot_mapping
=
self
.
input_batch
.
slot_mapping
=
self
.
input_batch
.
block_table
[
block_table
[
kv_cache_group_id
].
slot_mapping
[:
num_tokens
],
kv_cache_group_id
].
slot_mapping
.
gpu
[:
num_tokens
],
causal
=
True
)
causal
=
True
)
for
attn_group
in
self
.
attn_groups
[
kv_cache_group_id
]:
for
attn_group
in
self
.
attn_groups
[
kv_cache_group_id
]:
if
ubatch_slices
is
not
None
:
if
ubatch_slices
is
not
None
:
...
@@ -3265,8 +3263,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3265,8 +3263,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
uniform_decode
=
False
)
uniform_decode
=
False
)
# Capture full cudagraph for uniform decode batches if we
have
# Capture full cudagraph for uniform decode batches if we
# dont already have full mixed prefill-decode cudagraphs
# don
'
t already have full mixed prefill-decode cudagraphs
.
if
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
\
if
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
\
cudagraph_mode
.
separate_routine
():
cudagraph_mode
.
separate_routine
():
max_num_tokens
=
self
.
scheduler_config
.
max_num_seqs
*
\
max_num_tokens
=
self
.
scheduler_config
.
max_num_seqs
*
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment