Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2d7620c3
Unverified
Commit
2d7620c3
authored
Jun 25, 2025
by
Chenyaaang
Committed by
GitHub
Jun 25, 2025
Browse files
[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)
Signed-off-by:
Chenyaaang
<
chenyangli@google.com
>
parent
55c65ab4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
185 additions
and
77 deletions
+185
-77
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+14
-0
vllm/envs.py
vllm/envs.py
+3
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+0
-10
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+5
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+163
-67
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
2d7620c3
...
...
@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_most_model_len
(
monkeypatch
:
pytest
.
MonkeyPatch
):
monkeypatch
.
setenv
(
"VLLM_TPU_MOST_MODEL_LEN"
,
"2048"
)
vllm_config
=
get_vllm_config
()
vllm_config
.
model_config
.
max_model_len
=
32000
vllm_config
.
scheduler_config
.
max_num_seqs
=
1200
model_runner
=
get_model_runner
(
vllm_config
)
# verify model runner will adjust num_reqs to avoid SMEM OOM.
assert
model_runner
.
num_reqs_most_model_len
==
1200
# num_page_per_req = 32k // 128
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
assert
model_runner
.
num_reqs_max_model_len
==
524
vllm/envs.py
View file @
2d7620c3
...
...
@@ -119,6 +119,7 @@ if TYPE_CHECKING:
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_TPU_MOST_MODEL_LEN
:
Optional
[
int
]
=
None
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD
:
int
=
256
...
...
@@ -833,6 +834,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_BUCKET_PADDING_GAP"
:
lambda
:
int
(
os
.
environ
[
"VLLM_TPU_BUCKET_PADDING_GAP"
])
if
"VLLM_TPU_BUCKET_PADDING_GAP"
in
os
.
environ
else
0
,
"VLLM_TPU_MOST_MODEL_LEN"
:
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_TPU_MOST_MODEL_LEN"
,
None
)),
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
...
...
vllm/platforms/tpu.py
View file @
2d7620c3
...
...
@@ -122,16 +122,6 @@ class TpuPlatform(Platform):
PallasAttentionBackend
)
cache_config
.
block_size
=
PallasAttentionBackend
.
get_page_size
(
vllm_config
)
# type: ignore[assignment]
min_page_size
=
PallasAttentionBackend
.
get_min_page_size
(
vllm_config
)
if
min_page_size
>
cache_config
.
block_size
:
logger
.
warning
(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM"
,
cache_config
.
block_size
,
min_page_size
,
)
cache_config
.
block_size
=
min_page_size
# type: ignore[assignment]
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
...
...
vllm/v1/attention/backends/pallas.py
View file @
2d7620c3
...
...
@@ -71,6 +71,11 @@ class PallasAttentionBackend(AttentionBackend):
min_page_size
=
1
<<
(
min_page_size
-
1
).
bit_length
()
return
min_page_size
@
staticmethod
def
get_max_num_seqs
(
model_len
:
int
,
page_size
:
int
)
->
int
:
num_page_per_req
=
cdiv
(
model_len
,
page_size
)
return
1024
*
1024
//
2
//
num_page_per_req
//
4
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
2d7620c3
...
...
@@ -37,8 +37,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
Logprobs
Tensor
s
,
ModelRunnerOutput
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
Logprobs
List
s
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
...
...
@@ -150,7 +150,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
max_model_len
=
model_config
.
max_model_len
self
.
most_model_len
=
envs
.
VLLM_TPU_MOST_MODEL_LEN
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
num_blocks_per_most_len_req
=
cdiv
(
self
.
most_model_len
,
self
.
block_size
)
if
self
.
most_model_len
is
not
None
else
None
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self
.
max_num_reqs
=
max
(
scheduler_config
.
max_num_seqs
,
MIN_NUM_SEQS
)
...
...
@@ -220,12 +224,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype
=
torch
.
int32
,
device
=
"cpu"
)
self
.
positions_np
=
self
.
positions_cpu
.
numpy
()
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
self
.
max_num_blocks_per_req
),
dtype
=
torch
.
int32
,
device
=
"cpu"
)
# adjust num_reqs to avoid SMEM OOM.
self
.
num_reqs_most_model_len
=
min
(
PallasAttentionBackend
.
get_max_num_seqs
(
self
.
most_model_len
,
self
.
block_size
),
self
.
max_num_reqs
)
if
self
.
most_model_len
is
not
None
else
None
self
.
num_reqs_max_model_len
=
min
(
PallasAttentionBackend
.
get_max_num_seqs
(
self
.
max_model_len
,
self
.
block_size
),
self
.
max_num_reqs
)
self
.
query_start_loc_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
...
...
@@ -515,25 +526,50 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return
kv_cache_spec
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
):
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
start_index
:
int
):
assert
scheduler_output
.
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
assert
num_reqs
>
0
assert
start_index
<
num_reqs
# Get the number of scheduled tokens for each request.
use_max_model_len
=
self
.
most_model_len
is
None
num_scheduled_tokens_per_req
=
[]
max_num_scheduled_tokens_all_reqs
=
0
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
end_index
=
start_index
# Use either most_model_len or max_model_len depending on request size.
for
i
in
range
(
start_index
,
num_reqs
):
req_id
=
self
.
input_batch
.
req_ids
[
i
]
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
if
not
use_max_model_len
and
num_tokens
>
self
.
most_model_len
:
use_max_model_len
=
True
num_scheduled_tokens_per_req
.
append
(
num_tokens
)
max_num_scheduled_tokens_all_reqs
=
max
(
max_num_scheduled_tokens_all_reqs
,
num_tokens
)
if
use_max_model_len
:
if
len
(
num_scheduled_tokens_per_req
)
>
self
.
num_reqs_max_model_len
:
num_scheduled_tokens_per_req
=
\
num_scheduled_tokens_per_req
[:
self
.
num_reqs_max_model_len
]
end_index
=
start_index
+
self
.
num_reqs_max_model_len
else
:
end_index
=
num_reqs
else
:
if
len
(
num_scheduled_tokens_per_req
)
>
self
.
num_reqs_most_model_len
:
num_scheduled_tokens_per_req
=
\
num_scheduled_tokens_per_req
[:
self
.
num_reqs_most_model_len
]
end_index
=
start_index
+
self
.
num_reqs_most_model_len
else
:
end_index
=
num_reqs
max_num_scheduled_tokens_all_reqs
=
max
(
num_scheduled_tokens_per_req
)
num_scheduled_tokens_per_req
=
np
.
array
(
num_scheduled_tokens_per_req
,
dtype
=
np
.
int32
)
total_num_scheduled_tokens
=
sum
(
num_scheduled_tokens_per_req
)
assert
max_num_scheduled_tokens_all_reqs
>
0
num_reqs
=
len
(
num_scheduled_tokens_per_req
)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# For each scheduled token, what are the corresponding req index.
...
...
@@ -615,13 +651,29 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
block_table
[
0
].
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
))
block_tables
=
self
.
block_table_cpu
[:
self
.
max_num_reqs
]
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
()[:
num_reqs
])
if
use_max_model_len
:
block_tables
=
self
.
block_table_cpu
[:
self
.
num_reqs_max_model_len
,
:
self
.
max_num_blocks_per_req
]
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
()[:
num_reqs
])
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
num_reqs_max_model_len
+
1
].
to
(
self
.
device
)
seq_lens
=
self
.
seq_lens_cpu
[:
self
.
num_reqs_max_model_len
].
to
(
self
.
device
)
else
:
block_tables
=
self
.
block_table_cpu
[:
self
.
num_reqs_most_model_len
,
:
self
.
num_blocks_per_most_len_req
]
block_tables
[:
num_reqs
,
:
self
.
num_blocks_per_most_len_req
]
=
(
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
()
[:
num_reqs
,
:
self
.
num_blocks_per_most_len_req
])
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
num_reqs_most_model_len
+
1
].
to
(
self
.
device
)
seq_lens
=
self
.
seq_lens_cpu
[:
self
.
num_reqs_most_model_len
].
to
(
self
.
device
)
block_tables
=
block_tables
.
to
(
self
.
device
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
max_num_reqs
+
1
].
to
(
self
.
device
)
seq_lens
=
self
.
seq_lens_cpu
[:
self
.
max_num_reqs
].
to
(
self
.
device
)
if
self
.
lora_config
is
not
None
:
# We need to respect padding when activating LoRA adapters
...
...
@@ -672,7 +724,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
layer_name
:
attn_metadata
for
layer_name
in
layer_names
}
return
per_layer_attn_metadata
,
logits_indices
,
padded_num_reqs
return
per_layer_attn_metadata
,
logits_indices
,
padded_num_reqs
,
\
num_reqs
,
end_index
def
_scatter_placeholders
(
self
,
...
...
@@ -847,52 +900,84 @@ class TPUModelRunner(LoRAModelRunnerMixin):
else
:
mm_embeds
=
[]
xm
.
mark_step
()
# Prepare inputs
attn_metadata
,
logits_indices
,
padded_num_reqs
=
self
.
_prepare_inputs
(
scheduler_output
)
input_ids
,
inputs_embeds
=
self
.
_get_model_inputs
(
self
.
input_ids
,
mm_embeds
)
xm
.
mark_step
()
num_reqs
=
self
.
input_batch
.
num_reqs
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
inputs_embeds
=
inputs_embeds
,
)
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
logits_indices
)
logits
=
self
.
compute_logits
(
hidden_states
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
if
scheduler_output
.
grammar_bitmask
is
not
None
:
require_struct_decoding
,
grammar_bitmask_padded
,
arange
=
\
self
.
prepare_structured_decoding_input
(
logits
,
scheduler_output
)
logits
=
self
.
structured_decode
(
require_struct_decoding
,
grammar_bitmask_padded
,
logits
,
arange
)
selected_token_ids
=
self
.
sample_from_logits_func
(
logits
,
tpu_sampling_metadata
)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
# `sample_from_logits` does not modify the logits in-place.
logprobs
=
self
.
gather_logprobs
(
logits
,
selected_token_ids
)
\
if
tpu_sampling_metadata
.
logprobs
else
None
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
logprobs_lists
=
logprobs
.
tolists
()
\
if
tpu_sampling_metadata
.
logprobs
else
None
# Prepare inputs, the requests might be splitted into multiple
# executions, combine the result of each execution.
start_index
=
0
combined_selected_tokens
:
list
[
torch
.
Tensor
]
=
[]
combined_logprobs
:
list
[
LogprobsLists
]
=
[]
while
start_index
<
self
.
input_batch
.
num_reqs
:
attn_metadata
,
logits_indices
,
padded_num_reqs
,
num_reqs
,
\
end_index
=
self
.
_prepare_inputs
(
scheduler_output
,
start_index
)
input_ids
,
inputs_embeds
=
self
.
_get_model_inputs
(
self
.
input_ids
,
mm_embeds
)
xm
.
mark_step
()
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
inputs_embeds
=
inputs_embeds
,
)
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
logits_indices
)
logits
=
self
.
compute_logits
(
hidden_states
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
if
scheduler_output
.
grammar_bitmask
is
not
None
:
require_struct_decoding
,
grammar_bitmask_padded
,
arange
=
\
self
.
prepare_structured_decoding_input
(
logits
,
scheduler_output
)
logits
=
self
.
structured_decode
(
require_struct_decoding
,
grammar_bitmask_padded
,
logits
,
arange
)
selected_token_ids
=
self
.
sample_from_logits_func
(
logits
,
tpu_sampling_metadata
)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it
# due to recompilations outside torch.compiled code, so just make
# sure `sample_from_logits` does not modify the logits in-place.
logprobs
=
self
.
gather_logprobs
(
logits
,
selected_token_ids
)
\
if
tpu_sampling_metadata
.
logprobs
else
None
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
combined_selected_tokens
.
append
(
selected_token_ids
)
if
tpu_sampling_metadata
.
logprobs
:
combined_logprobs
.
append
(
logprobs
.
tolists
())
start_index
=
end_index
selected_token_ids
=
torch
.
cat
(
combined_selected_tokens
,
dim
=
0
)
if
tpu_sampling_metadata
.
logprobs
:
def
concat_lists
(
input_lists
):
result
=
[]
for
input_list
in
input_lists
:
result
.
extend
(
input_list
)
return
result
logprobs_lists
=
LogprobsLists
(
logprob_token_ids
=
concat_lists
(
[
lp
.
logprob_token_ids
for
lp
in
combined_logprobs
]),
logprobs
=
concat_lists
([
lp
.
logprobs
for
lp
in
combined_logprobs
]),
sampled_token_ranks
=
concat_lists
([
lp
.
sampled_token_ranks
for
lp
in
combined_logprobs
]))
else
:
logprobs_lists
=
None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens
:
list
[
tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
discard_sampled_tokens_req_indices
=
[]
num_reqs
=
self
.
input_batch
.
num_reqs
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -1020,7 +1105,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
sampler
=
TPUSampler
()
@
torch
.
no_grad
()
def
_dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
def
_dummy_run
(
self
,
num_tokens
:
int
,
num_reqs
:
int
,
num_blocks
:
int
)
->
None
:
if
self
.
is_multimodal_model
:
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
...
...
@@ -1030,20 +1116,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
input_ids
=
torch
.
zeros
((
num_tokens
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
inputs_embeds
=
None
actual_num_reqs
=
min
(
num_tokens
,
self
.
max_
num_reqs
)
actual_num_reqs
=
min
(
num_tokens
,
num_reqs
)
position_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
).
to
(
self
.
device
)
slot_mapping
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int64
).
to
(
self
.
device
)
block_tables
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
self
.
block_table_cpu
.
shape
[
1
]),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
query_lens
=
[
1
]
*
self
.
max_num_reqs
block_tables
=
torch
.
zeros
((
num_reqs
,
num_blocks
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
query_lens
=
[
1
]
*
num_reqs
query_start_loc
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
),
dim
=
0
,
dtype
=
torch
.
int32
).
to
(
self
.
device
)
context_lens
=
torch
.
ones
((
self
.
max_
num_reqs
,
),
context_lens
=
torch
.
ones
((
num_reqs
,
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
num_seqs
=
torch
.
tensor
([
actual_num_reqs
],
dtype
=
torch
.
int32
).
to
(
self
.
device
)
...
...
@@ -1061,6 +1146,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
torch
.
_dynamo
.
mark_dynamic
(
input_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
block_tables
,
(
0
,
1
))
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
context_lens
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
query_start_loc
,
0
)
layer_names
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
...
...
@@ -1152,7 +1240,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start
=
time
.
perf_counter
()
for
num_tokens
in
self
.
num_tokens_paddings
:
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
num_tokens
)
self
.
_dummy_run
(
num_tokens
,
self
.
num_reqs_max_model_len
,
self
.
max_num_blocks_per_req
)
if
self
.
most_model_len
is
not
None
:
self
.
_dummy_run
(
num_tokens
,
self
.
num_reqs_most_model_len
,
self
.
num_blocks_per_most_len_req
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
...
...
@@ -1341,7 +1433,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
encoder_cache
[
"tmp"
]
=
dict
(
enumerate
(
dummy_encoder_outputs
))
# Trigger compilation for general shape.
self
.
_dummy_run
(
num_tokens
)
self
.
_dummy_run
(
num_tokens
,
self
.
num_reqs_max_model_len
,
self
.
max_num_blocks_per_req
)
if
self
.
most_model_len
is
not
None
:
self
.
_dummy_run
(
num_tokens
,
self
.
num_reqs_most_model_len
,
self
.
num_blocks_per_most_len_req
)
xm
.
mark_step
()
xm
.
wait_device_ops
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment