Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
be39e3cd
Unverified
Commit
be39e3cd
authored
Dec 12, 2024
by
youkaichao
Committed by
GitHub
Dec 13, 2024
Browse files
[core] clean up cudagraph batchsize padding logic (#10996)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
34f1a806
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
150 additions
and
104 deletions
+150
-104
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+3
-2
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+3
-2
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+2
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+2
-2
vllm/config.py
vllm/config.py
+105
-66
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+14
-6
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+14
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-9
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-3
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-4
No files found.
tests/models/decoder_only/language/test_jamba.py
View file @
be39e3cd
import
pytest
from
tests.utils
import
multi_gpu_test
from
vllm.
config
import
VllmConfig
from
vllm.
engine.arg_utils
import
EngineArgs
from
vllm.sampling_params
import
SamplingParams
from
...utils
import
check_outputs_equal
...
...
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
vllm_config
=
EngineArgs
(
model
=
model
).
create_engine_config
()
while
len
(
example_prompts
)
==
vllm_config
.
pad_for_cudagraph
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
be39e3cd
...
...
@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
import
pytest
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
vllm.
config
import
VllmConfig
from
vllm.
engine.arg_utils
import
EngineArgs
from
vllm.sampling_params
import
SamplingParams
from
...utils
import
check_outputs_equal
...
...
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
vllm_config
=
EngineArgs
(
model
=
model
).
create_engine_config
()
while
len
(
example_prompts
)
==
vllm_config
.
pad_for_cudagraph
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
be39e3cd
...
...
@@ -4,7 +4,6 @@ from typing import List
import
pytest
import
torch
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
...
...
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
expanded_batch_size
)
graph_batch_size
=
model_runner
.
vllm_config
.
pad_for_cudagraph
(
expanded_batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
...
...
tests/worker/test_model_runner.py
View file @
be39e3cd
...
...
@@ -3,7 +3,6 @@ from typing import List
import
pytest
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
...
...
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input
.
attn_metadata
,
model_input
.
attn_metadata
.
slot_mapping
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
expected_bs
=
VllmConfig
.
get_graph_batch_size
(
len
(
seq_group_metadata_list
))
expected_bs
=
model_runner
.
vllm_config
.
pad_for_cudagraph
(
len
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
...
...
vllm/config.py
View file @
be39e3cd
...
...
@@ -2354,6 +2354,12 @@ class CompilationConfig(BaseModel):
# not configurable, computed after init
compile_sizes
:
List
[
int
]
=
PrivateAttr
capture_sizes
:
List
[
int
]
=
PrivateAttr
max_capture_size
:
int
=
PrivateAttr
# optimization:
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to List[int] for better lookup performance.
bs_to_padded_graph_size
:
List
[
int
]
=
PrivateAttr
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
...
...
@@ -2365,6 +2371,19 @@ class CompilationConfig(BaseModel):
# Map from layer name to the attention cls
static_forward_context
:
Dict
[
str
,
Any
]
=
PrivateAttr
def
__repr__
(
self
)
->
str
:
exclude
=
{
"static_forward_context"
,
"enabled_custom_ops"
,
"disabled_custom_ops"
,
"compilation_time"
,
"bs_to_padded_graph_size"
,
"pass_config"
,
}
return
self
.
model_dump_json
(
exclude
=
exclude
,
exclude_unset
=
True
)
__str__
=
__repr__
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"CompilationConfig"
:
"""Parse the CLI value for the compilation config."""
...
...
@@ -2450,18 +2469,22 @@ class CompilationConfig(BaseModel):
# sort to make sure cudagraph capture sizes are in descending order
self
.
capture_sizes
.
sort
(
reverse
=
True
)
self
.
max_capture_size
=
self
.
capture_sizes
[
0
]
if
self
.
capture_sizes
else
0
_BATCH_SIZE_ALIGNMENT
=
8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
# pre-compute the mapping from batch size to padded graph size
self
.
bs_to_padded_graph_size
=
[
0
for
i
in
range
(
self
.
max_capture_size
+
1
)
]
for
end
,
start
in
zip
(
self
.
capture_sizes
,
self
.
capture_sizes
[
1
:]
+
[
0
]):
for
bs
in
range
(
start
,
end
):
if
bs
==
start
:
self
.
bs_to_padded_graph_size
[
bs
]
=
start
else
:
self
.
bs_to_padded_graph_size
[
bs
]
=
end
self
.
bs_to_padded_graph_size
[
self
.
max_capture_size
]
=
self
.
max_capture_size
@
dataclass
...
...
@@ -2491,40 +2514,12 @@ class VllmConfig:
init
=
True
)
# type: ignore
instance_id
:
str
=
""
@
staticmethod
def
get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
@
staticmethod
def
get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
VllmConfig
.
get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
def
pad_for_cudagraph
(
self
,
batch_size
:
int
)
->
int
:
# if batch_size > self.compilation_config.max_capture_size,
# it should raise an IndexError.
# the caller should make sure the batch_size is within the range,
# i.e., batch_size <= self.compilation_config.max_capture_size
return
self
.
compilation_config
.
bs_to_padded_graph_size
[
batch_size
]
@
staticmethod
def
_get_quantization_config
(
...
...
@@ -2618,27 +2613,7 @@ class VllmConfig:
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
if
not
envs
.
VLLM_USE_V1
:
max_batchsize_to_capture
=
0
if
self
.
scheduler_config
is
not
None
and
\
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
max_batchsize_to_capture
=
\
self
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
size
for
size
in
_BATCH_SIZES_TO_CAPTURE
if
size
<=
max_batchsize_to_capture
]
else
:
batch_size_capture_list
=
[]
if
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
self
.
_set_cudagraph_sizes
()
if
self
.
cache_config
is
not
None
and
\
self
.
cache_config
.
cpu_offload_gb
>
0
and
\
...
...
@@ -2659,6 +2634,70 @@ class VllmConfig:
if
not
self
.
instance_id
:
self
.
instance_id
=
random_uuid
()[:
5
]
def
_set_cudagraph_sizes
(
self
):
"""
cudagraph batchsize padding logic:
`[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
batch sizes that cudagraph will capture.
Depending on the engine's configuration of `max_num_seqs`, the
candidate batch sizes to capture cudagraph will shrink to the subset
which just cover the range of `[1, max_num_seqs]`. In the common case,
`max_num_seqs` is 256, and the cudagraph batch sizes will be
`[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
However, if users specify the cudagraph capture sizes through
compilation config, we will use the specified sizes instead.
In the end, `vllm_config.compilation_config.capture_sizes` will be the
final sizes to capture cudagraph (in descending order).
During runtime, if batchsize is larger than
`vllm_config.compilation_config.capture_sizes`,
no cudagraph will be used.
If the batch size is no larger than
`vllm_config.compilation_config.capture_sizes`,
we can quickly find the padded graph size for a given batch size by
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
"""
# calculate the default `batch_size_capture_list`
if
not
envs
.
VLLM_USE_V1
:
batch_size_capture_list
=
[]
max_batchsize_to_capture
=
0
if
self
.
scheduler_config
is
not
None
and
\
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
possible_sizes
=
[
1
,
2
,
4
]
+
[
8
*
i
for
i
in
range
(
1
,
1025
)]
# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes
=
[
x
for
x
in
possible_sizes
if
x
>=
self
.
scheduler_config
.
max_num_seqs
]
if
larger_sizes
:
max_batchsize_to_capture
=
larger_sizes
[
0
]
else
:
max_batchsize_to_capture
=
possible_sizes
[
-
1
]
# filter out the sizes that are
# larger than max_batchsize_to_capture
batch_size_capture_list
=
[
size
for
size
in
possible_sizes
if
size
<=
max_batchsize_to_capture
]
else
:
batch_size_capture_list
=
[]
if
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
def
__str__
(
self
):
return
(
f
"model=
{
self
.
model_config
.
model
!
r
}
,"
...
...
vllm/model_executor/models/jamba.py
View file @
be39e3cd
...
...
@@ -7,7 +7,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
...
@@ -420,6 +420,17 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
if
self
.
scheduler_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
if
self
.
scheduler_config
.
max_num_seqs
>
\
vllm_config
.
compilation_config
.
max_capture_size
:
self
.
max_batch_size
=
\
vllm_config
.
compilation_config
.
max_capture_size
else
:
self
.
max_batch_size
=
vllm_config
.
pad_for_cudagraph
(
self
.
scheduler_config
.
max_num_seqs
)
else
:
self
.
max_batch_size
=
8192
+
2
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
@@ -433,15 +444,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
self
.
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
(
mamba_cache_tensors
,
state_indices_tensor
,
...
...
vllm/model_executor/models/mamba.py
View file @
be39e3cd
...
...
@@ -6,7 +6,7 @@ from torch import nn
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -195,6 +195,17 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
backbone
.
make_empty_intermediate_tensors
)
if
self
.
scheduler_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
if
self
.
scheduler_config
.
max_num_seqs
>
\
vllm_config
.
compilation_config
.
max_capture_size
:
self
.
max_batch_size
=
\
vllm_config
.
compilation_config
.
max_capture_size
else
:
self
.
max_batch_size
=
vllm_config
.
pad_for_cudagraph
(
self
.
scheduler_config
.
max_num_seqs
)
else
:
self
.
max_batch_size
=
8192
+
2
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
backbone
.
get_input_embeddings
(
input_ids
)
...
...
@@ -208,15 +219,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
num_mamba_layers
=
self
.
model_config
.
get_num_layers_by_block_type
(
self
.
vllm_config
.
parallel_config
,
LayerBlockType
.
mamba
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
self
.
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
(
mamba_cache_tensors
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
be39e3cd
import
gc
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
import
numpy
as
np
import
torch
...
...
@@ -459,7 +459,7 @@ class GPUModelRunner:
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
_get_padded_batch_size
(
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
# Eager mode.
...
...
@@ -641,10 +641,3 @@ class GPUModelRunner:
torch
.
zeros
(
kv_cache_shape
,
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
))
def
_get_padded_batch_size
(
self
,
batch_size
:
int
)
->
Optional
[
int
]:
# TODO: Optimize this?
for
size
in
self
.
cudagraph_batch_sizes
:
if
batch_size
<=
size
:
return
size
return
None
vllm/worker/enc_dec_model_runner.py
View file @
be39e3cd
...
...
@@ -464,7 +464,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# We will be using CUDA graph replay for this decode.
max_len_of_block_table
=
self
.
get_max_block_per_batch
()
batch_size
=
len
(
encoder_seq_lens
)
graph_batch_size
=
self
.
vllm_config
.
get_graph_batch_size
(
graph_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
assert
graph_batch_size
>=
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
...
...
vllm/worker/model_runner.py
View file @
be39e3cd
...
...
@@ -802,7 +802,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_encoder_seq_len
):
return
-
1
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
batch_size
)
graph_batch_size
=
self
.
runner
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
assert
graph_batch_size
>=
batch_size
return
graph_batch_size
-
batch_size
...
...
@@ -1014,8 +1015,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_batchsize_to_capture
=
VllmConfig
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
self
.
max_batchsize_to_capture
=
\
self
.
vllm_config
.
compilation_config
.
max_capture_size
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
{}
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
...
...
vllm/worker/xpu_model_runner.py
View file @
be39e3cd
...
...
@@ -37,10 +37,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
_BATCH_SIZE_ALIGNMENT
=
8
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
TModelInputForXPU
=
TypeVar
(
'TModelInputForXPU'
,
bound
=
"ModelInputForXPU"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment