Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8ab988b
Unverified
Commit
c8ab988b
authored
Dec 04, 2025
by
Lucas Wilkinson
Committed by
GitHub
Dec 04, 2025
Browse files
[BugFix] Fix DBO assert `assert B_block_table == B_q` (#29933)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
48a5fff6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
63 deletions
+83
-63
tests/v1/attention/test_attention_splitting.py
tests/v1/attention/test_attention_splitting.py
+9
-3
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-2
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+4
-39
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-16
vllm/v1/worker/ubatch_utils.py
vllm/v1/worker/ubatch_utils.py
+39
-3
No files found.
tests/v1/attention/test_attention_splitting.py
View file @
c8ab988b
...
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
...
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
split_attn_metadata
,
split_attn_metadata
,
split_decodes_and_prefills
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.worker.ubatch_utils
import
create_ubatch_slices
from
vllm.v1.worker.ubatch_utils
import
maybe_
create_ubatch_slices
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
...
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
qsl_np
=
common
.
query_start_loc_cpu
.
numpy
()
qsl_np
=
common
.
query_start_loc_cpu
.
numpy
()
num_tokens
=
common
.
num_actual_tokens
num_tokens
=
common
.
num_actual_tokens
ubatch_slices
=
create_ubatch_slices
(
num_scheduled_tokens
,
split_point
)
ubatch_slices
,
_
=
maybe_create_ubatch_slices
(
assert
len
(
ubatch_slices
)
==
2
True
,
num_scheduled_tokens
,
num_tokens
,
batch_spec
.
batch_size
,
split_point
=
split_point
,
)
assert
ubatch_slices
is
not
None
and
len
(
ubatch_slices
)
==
2
first_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
0
],
common
)
first_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
0
],
common
)
second_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
1
],
common
)
second_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
1
],
common
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
c8ab988b
...
@@ -1258,7 +1258,7 @@ class EagleProposer:
...
@@ -1258,7 +1258,7 @@ class EagleProposer:
num_tokens_padded
:
int
,
num_tokens_padded
:
int
,
)
->
tuple
[
int
,
torch
.
Tensor
]:
)
->
tuple
[
int
,
torch
.
Tensor
]:
# TODO(Flechman): support DBO ubatching
# TODO(Flechman): support DBO ubatching
ubatch_slices
,
num_toks_across_dp
=
coordinate_batch_across_dp
(
should_ubatch
,
num_toks_across_dp
=
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens_unpadded
,
num_tokens_unpadded
=
num_tokens_unpadded
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
allow_microbatching
=
False
,
...
@@ -1267,7 +1267,7 @@ class EagleProposer:
...
@@ -1267,7 +1267,7 @@ class EagleProposer:
uniform_decode
=
None
,
uniform_decode
=
None
,
num_scheduled_tokens_per_request
=
None
,
num_scheduled_tokens_per_request
=
None
,
)
)
assert
ubatch_slices
is
None
,
"DBO ubatching not implemented for EAGLE"
assert
not
should_ubatch
,
"DBO ubatching not implemented for EAGLE"
num_tokens_dp_padded
=
num_tokens_padded
num_tokens_dp_padded
=
num_tokens_padded
if
num_toks_across_dp
is
not
None
:
if
num_toks_across_dp
is
not
None
:
...
...
vllm/v1/worker/dp_utils.py
View file @
c8ab988b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -9,10 +10,7 @@ from vllm.config import ParallelConfig
...
@@ -9,10 +10,7 @@ from vllm.config import ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.worker.ubatch_utils
import
(
from
vllm.v1.worker.ubatch_utils
import
(
UBatchSlice
,
UBatchSlices
,
check_ubatch_thresholds
,
check_ubatch_thresholds
,
create_ubatch_slices
,
is_second_ubatch_empty
,
is_second_ubatch_empty
,
)
)
...
@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
...
@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return
num_tokens_across_dp
.
cpu
()
return
num_tokens_across_dp
.
cpu
()
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def
_pad_out_ubatch_slice
(
ubatch_slices
:
UBatchSlices
,
num_total_tokens
:
int
)
->
UBatchSlices
:
padded_second_token_slice
=
slice
(
ubatch_slices
[
1
].
token_slice
.
start
,
num_total_tokens
)
ubatch_slices
[
1
]
=
UBatchSlice
(
ubatch_slices
[
1
].
request_slice
,
padded_second_token_slice
)
return
ubatch_slices
def
_synchronize_dp_ranks
(
def
_synchronize_dp_ranks
(
num_tokens_unpadded
:
int
,
num_tokens_unpadded
:
int
,
num_tokens_padded
:
int
,
num_tokens_padded
:
int
,
...
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
...
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
num_tokens_padded
:
int
|
None
=
None
,
num_tokens_padded
:
int
|
None
=
None
,
uniform_decode
:
bool
|
None
=
None
,
uniform_decode
:
bool
|
None
=
None
,
num_scheduled_tokens_per_request
:
np
.
ndarray
|
None
=
None
,
num_scheduled_tokens_per_request
:
np
.
ndarray
|
None
=
None
,
)
->
tuple
[
UBatchSlices
|
None
,
torch
.
Tensor
|
None
]:
)
->
tuple
[
bool
,
torch
.
Tensor
|
None
]:
"""
"""
Coordinates amongst all DP ranks to determine if and how the full batch
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
should be split into microbatches.
...
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
...
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
"""
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
:
# Early exit.
# Early exit.
return
Non
e
,
None
return
Fals
e
,
None
# If the caller has explicitly enabled microbatching.
# If the caller has explicitly enabled microbatching.
should_attempt_ubatching
=
False
should_attempt_ubatching
=
False
...
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
...
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
parallel_config
,
parallel_config
,
)
)
# Don't microbatch unless every other DP worker is also microbatching
return
(
should_ubatch
,
num_tokens_after_padding
)
if
not
should_ubatch
:
return
(
None
,
num_tokens_after_padding
)
# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert
num_tokens_after_padding
is
not
None
num_tokens_padded
=
int
(
num_tokens_after_padding
[
0
].
item
())
token_split_point
=
int
(
num_tokens_padded
)
//
2
assert
num_scheduled_tokens_per_request
is
not
None
ubatch_slices
=
create_ubatch_slices
(
num_scheduled_tokens_per_request
,
token_split_point
)
ubatch_slices
=
_pad_out_ubatch_slice
(
ubatch_slices
,
num_tokens_padded
)
assert
sum
(
s
.
num_tokens
for
s
in
ubatch_slices
)
==
num_tokens_padded
return
(
ubatch_slices
,
num_tokens_after_padding
)
vllm/v1/worker/gpu_model_runner.py
View file @
c8ab988b
...
@@ -153,6 +153,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
...
@@ -153,6 +153,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from
vllm.v1.worker.ubatch_utils
import
(
from
vllm.v1.worker.ubatch_utils
import
(
UBatchSlices
,
UBatchSlices
,
check_ubatch_thresholds
,
check_ubatch_thresholds
,
maybe_create_ubatch_slices
,
)
)
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
...
@@ -2743,7 +2744,7 @@ class GPUModelRunner(
...
@@ -2743,7 +2744,7 @@ class GPUModelRunner(
)
->
tuple
[
)
->
tuple
[
CUDAGraphMode
,
CUDAGraphMode
,
BatchDescriptor
,
BatchDescriptor
,
UBatchSlices
|
None
,
bool
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
CUDAGraphStat
|
None
,
CUDAGraphStat
|
None
,
]:
]:
...
@@ -2779,7 +2780,7 @@ class GPUModelRunner(
...
@@ -2779,7 +2780,7 @@ class GPUModelRunner(
# Extra coordination when running data-parallel since we need to coordinate
# Extra coordination when running data-parallel since we need to coordinate
# across ranks
# across ranks
ubatch_slices
,
num_tokens_across_dp
=
Non
e
,
None
should_ubatch
,
num_tokens_across_dp
=
Fals
e
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
# Disable DP padding when running eager to avoid excessive padding when
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
...
@@ -2789,8 +2790,8 @@ class GPUModelRunner(
...
@@ -2789,8 +2790,8 @@ class GPUModelRunner(
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
)
)
ubatch_slices
,
num_tokens_across_dp
=
coordinate_batch_across_dp
(
should_ubatch
,
num_tokens_across_dp
=
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
_padded
,
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
allow_microbatching
=
allow_microbatching
,
allow_microbatching
=
allow_microbatching
,
allow_dp_padding
=
allow_dp_padding
,
allow_dp_padding
=
allow_dp_padding
,
...
@@ -2822,7 +2823,7 @@ class GPUModelRunner(
...
@@ -2822,7 +2823,7 @@ class GPUModelRunner(
return
(
return
(
cudagraph_mode
,
cudagraph_mode
,
batch_descriptor
,
batch_descriptor
,
ubatch_slices
,
should_ubatch
,
num_tokens_across_dp
,
num_tokens_across_dp
,
cudagraph_stats
,
cudagraph_stats
,
)
)
...
@@ -2921,7 +2922,7 @@ class GPUModelRunner(
...
@@ -2921,7 +2922,7 @@ class GPUModelRunner(
(
(
cudagraph_mode
,
cudagraph_mode
,
batch_desc
,
batch_desc
,
ubatch_slices
,
should_ubatch
,
num_tokens_across_dp
,
num_tokens_across_dp
,
cudagraph_stats
,
cudagraph_stats
,
)
=
self
.
_determine_batch_execution_and_padding
(
)
=
self
.
_determine_batch_execution_and_padding
(
...
@@ -2934,10 +2935,10 @@ class GPUModelRunner(
...
@@ -2934,10 +2935,10 @@ class GPUModelRunner(
logger
.
debug
(
logger
.
debug
(
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
"
ubatch_slices
: %s, num_tokens_across_dp: %s"
,
"
should_ubatch
: %s, num_tokens_across_dp: %s"
,
cudagraph_mode
,
cudagraph_mode
,
batch_desc
,
batch_desc
,
ubatch_slices
,
should_ubatch
,
num_tokens_across_dp
,
num_tokens_across_dp
,
)
)
...
@@ -2945,10 +2946,18 @@ class GPUModelRunner(
...
@@ -2945,10 +2946,18 @@ class GPUModelRunner(
num_reqs_padded
=
(
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
)
)
ubatch_slices
,
ubatch_slices_padded
=
maybe_create_ubatch_slices
(
should_ubatch
,
num_scheduled_tokens_np
,
num_tokens_padded
,
num_reqs_padded
,
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
pad_attn
=
cudagraph_mode
==
CUDAGraphMode
.
FULL
pad_attn
=
cudagraph_mode
==
CUDAGraphMode
.
FULL
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
ubatch_slices_attn
=
ubatch_slices_padded
if
pad_attn
else
ubatch_slices
(
attn_metadata
,
spec_decode_common_attn_metadata
)
=
(
(
attn_metadata
,
spec_decode_common_attn_metadata
)
=
(
self
.
_build_attention_metadata
(
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_tokens
=
num_tokens_unpadded
,
...
@@ -2956,7 +2965,7 @@ class GPUModelRunner(
...
@@ -2956,7 +2965,7 @@ class GPUModelRunner(
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_reqs_padded
=
num_reqs_padded
if
pad_attn
else
None
,
num_reqs_padded
=
num_reqs_padded
if
pad_attn
else
None
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices
_attn
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
use_spec_decode
=
use_spec_decode
,
use_spec_decode
=
use_spec_decode
,
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
,
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
,
...
@@ -2993,7 +3002,7 @@ class GPUModelRunner(
...
@@ -2993,7 +3002,7 @@ class GPUModelRunner(
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_mode
,
cudagraph_runtime_mode
=
cudagraph_mode
,
batch_descriptor
=
batch_desc
,
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices
_padded
,
),
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
...
@@ -3945,7 +3954,7 @@ class GPUModelRunner(
...
@@ -3945,7 +3954,7 @@ class GPUModelRunner(
num_sampled_tokens
=
np
.
ones
(
num_reqs
,
dtype
=
np
.
int32
)
num_sampled_tokens
=
np
.
ones
(
num_reqs
,
dtype
=
np
.
int32
)
_cudagraph_mode
,
batch_desc
,
ubatch_slices
,
num_tokens_across_dp
,
_
=
(
_cudagraph_mode
,
batch_desc
,
should_ubatch
,
num_tokens_across_dp
,
_
=
(
self
.
_determine_batch_execution_and_padding
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
=
num_tokens_unpadded
,
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -3979,6 +3988,9 @@ class GPUModelRunner(
...
@@ -3979,6 +3988,9 @@ class GPUModelRunner(
num_reqs_padded
=
(
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
)
)
ubatch_slices
,
ubatch_slices_padded
=
maybe_create_ubatch_slices
(
should_ubatch
,
num_scheduled_tokens
,
num_tokens_padded
,
num_reqs_padded
)
attn_metadata
:
PerLayerAttnMetadata
|
None
=
None
attn_metadata
:
PerLayerAttnMetadata
|
None
=
None
...
@@ -4000,11 +4012,12 @@ class GPUModelRunner(
...
@@ -4000,11 +4012,12 @@ class GPUModelRunner(
self
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
cum_num_tokens
self
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
cum_num_tokens
self
.
query_start_loc
.
copy_to_gpu
()
self
.
query_start_loc
.
copy_to_gpu
()
pad_attn
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
attn_metadata
,
_
=
self
.
_build_attention_metadata
(
attn_metadata
,
_
=
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs_padded
,
num_reqs
=
num_reqs_padded
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices_padded
if
pad_attn
else
ubatch_slices
,
for_cudagraph_capture
=
is_graph_capturing
,
for_cudagraph_capture
=
is_graph_capturing
,
)
)
...
@@ -4056,11 +4069,11 @@ class GPUModelRunner(
...
@@ -4056,11 +4069,11 @@ class GPUModelRunner(
num_tokens_padded
,
None
,
False
num_tokens_padded
,
None
,
False
)
)
if
ubatch_slices
is
not
None
:
if
ubatch_slices
_padded
is
not
None
:
# Adjust values to reflect a single ubatch.
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
# TODO(sage,lucas): this is cruft that should be addressed in
# the padding refactor.
# the padding refactor.
num_tokens_padded
=
ubatch_slices
[
0
].
num_tokens
num_tokens_padded
=
ubatch_slices
_padded
[
0
].
num_tokens
if
num_tokens_across_dp
is
not
None
:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[:]
=
num_tokens_padded
num_tokens_across_dp
[:]
=
num_tokens_padded
...
@@ -4073,7 +4086,7 @@ class GPUModelRunner(
...
@@ -4073,7 +4086,7 @@ class GPUModelRunner(
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
batch_descriptor
=
batch_desc
,
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices
_padded
,
),
),
):
):
outputs
=
self
.
model
(
outputs
=
self
.
model
(
...
...
vllm/v1/worker/ubatch_utils.py
View file @
c8ab988b
...
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
...
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
return
num_tokens
>=
config
.
dbo_prefill_token_threshold
return
num_tokens
>=
config
.
dbo_prefill_token_threshold
def
create_ubatch_slices
(
# This just pads the second ubatch slice out to the total number of tokens
num_scheduled_tokens
:
np
.
ndarray
,
split_point
:
int
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def
_pad_out_ubatch_slices
(
ubatch_slices
:
UBatchSlices
,
num_total_tokens
:
int
,
num_reqs_padded
:
int
)
->
UBatchSlices
:
)
->
UBatchSlices
:
# TODO(lucas): handle empty second ubatch
padded_second_request_slice
=
slice
(
ubatch_slices
[
1
].
request_slice
.
start
,
num_reqs_padded
)
padded_second_token_slice
=
slice
(
ubatch_slices
[
1
].
token_slice
.
start
,
num_total_tokens
)
return
[
ubatch_slices
[
0
],
UBatchSlice
(
padded_second_request_slice
,
padded_second_token_slice
),
]
def
maybe_create_ubatch_slices
(
should_ubatch
:
bool
,
num_scheduled_tokens
:
np
.
ndarray
,
num_tokens_padded
:
int
,
num_reqs_padded
:
int
,
split_point
:
int
|
None
=
None
,
)
->
tuple
[
UBatchSlices
|
None
,
UBatchSlices
|
None
]:
if
not
should_ubatch
:
return
None
,
None
if
split_point
is
None
:
split_point
=
int
(
num_tokens_padded
)
//
2
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens
=
np
.
zeros
(
len
(
num_scheduled_tokens
)
+
1
,
dtype
=
np
.
int32
)
cu_num_tokens
=
np
.
zeros
(
len
(
num_scheduled_tokens
)
+
1
,
dtype
=
np
.
int32
)
...
@@ -67,7 +95,15 @@ def create_ubatch_slices(
...
@@ -67,7 +95,15 @@ def create_ubatch_slices(
)
)
second_ubatch_req_slice
=
slice
(
second_ubatch_req_start
,
len
(
cu_num_tokens
)
-
1
)
second_ubatch_req_slice
=
slice
(
second_ubatch_req_start
,
len
(
cu_num_tokens
)
-
1
)
return
[
ubatch_slices
=
[
UBatchSlice
(
first_ubatch_req_slice
,
first_ubatch_token_slice
),
UBatchSlice
(
first_ubatch_req_slice
,
first_ubatch_token_slice
),
UBatchSlice
(
second_ubatch_req_slice
,
second_ubatch_token_slice
),
UBatchSlice
(
second_ubatch_req_slice
,
second_ubatch_token_slice
),
]
]
ubatch_slices_padded
=
_pad_out_ubatch_slices
(
ubatch_slices
,
num_tokens_padded
,
num_reqs_padded
)
assert
sum
(
s
.
num_tokens
for
s
in
ubatch_slices_padded
)
==
num_tokens_padded
return
ubatch_slices
,
ubatch_slices_padded
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment