Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c6de96e
Unverified
Commit
8c6de96e
authored
Oct 07, 2024
by
Cyrus Leung
Committed by
GitHub
Oct 07, 2024
Browse files
[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
parent
18b296fd
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
342 additions
and
37 deletions
+342
-37
tests/conftest.py
tests/conftest.py
+20
-0
tests/models/test_oot_registration.py
tests/models/test_oot_registration.py
+15
-3
tests/models/test_registry.py
tests/models/test_registry.py
+22
-2
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
+6
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+34
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+7
-0
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+7
-21
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+191
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+31
-11
vllm/utils.py
vllm/utils.py
+9
-0
No files found.
tests/conftest.py
View file @
8c6de96e
...
...
@@ -871,6 +871,7 @@ def num_gpus_available():
temp_dir
=
tempfile
.
gettempdir
()
_dummy_opt_path
=
os
.
path
.
join
(
temp_dir
,
"dummy_opt"
)
_dummy_llava_path
=
os
.
path
.
join
(
temp_dir
,
"dummy_llava"
)
_dummy_gemma2_embedding_path
=
os
.
path
.
join
(
temp_dir
,
"dummy_gemma2_embedding"
)
@
pytest
.
fixture
...
...
@@ -909,3 +910,22 @@ def dummy_llava_path():
with
open
(
json_path
,
"w"
)
as
f
:
json
.
dump
(
config
,
f
)
return
_dummy_llava_path
@
pytest
.
fixture
def
dummy_gemma2_embedding_path
():
json_path
=
os
.
path
.
join
(
_dummy_gemma2_embedding_path
,
"config.json"
)
if
not
os
.
path
.
exists
(
_dummy_gemma2_embedding_path
):
snapshot_download
(
repo_id
=
"BAAI/bge-multilingual-gemma2"
,
local_dir
=
_dummy_gemma2_embedding_path
,
ignore_patterns
=
[
"*.bin"
,
"*.bin.index.json"
,
"*.pt"
,
"*.h5"
,
"*.msgpack"
])
assert
os
.
path
.
exists
(
json_path
)
with
open
(
json_path
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
config
[
"architectures"
]
=
[
"MyGemma2Embedding"
]
with
open
(
json_path
,
"w"
)
as
f
:
json
.
dump
(
config
,
f
)
return
_dummy_gemma2_embedding_path
tests/models/test_oot_registration.py
View file @
8c6de96e
...
...
@@ -2,7 +2,7 @@ import os
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
PoolingParams
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
..utils
import
fork_new_process_for_each_test
...
...
@@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):
@
fork_new_process_for_each_test
def
test_oot_registration
(
dummy_opt_path
):
def
test_oot_registration
_text_generation
(
dummy_opt_path
):
os
.
environ
[
"VLLM_PLUGINS"
]
=
"register_dummy_model"
prompts
=
[
"Hello, my name is"
,
"The text does not matter"
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
...
...
@@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
assert
rest
==
""
@
fork_new_process_for_each_test
def
test_oot_registration_embedding
(
dummy_gemma2_embedding_path
):
os
.
environ
[
"VLLM_PLUGINS"
]
=
"register_dummy_model"
prompts
=
[
"Hello, my name is"
,
"The text does not matter"
]
sampling_params
=
PoolingParams
()
llm
=
LLM
(
model
=
dummy_gemma2_embedding_path
,
load_format
=
"dummy"
)
outputs
=
llm
.
encode
(
prompts
,
sampling_params
)
for
output
in
outputs
:
assert
all
(
v
==
0
for
v
in
output
.
outputs
.
embedding
)
image
=
ImageAsset
(
"cherry_blossom"
).
pil_image
.
convert
(
"RGB"
)
@
fork_new_process_for_each_test
def
test_oot_
multimodal_
registration
(
dummy_llava_path
):
def
test_oot_registration
_multimodal
(
dummy_llava_path
):
os
.
environ
[
"VLLM_PLUGINS"
]
=
"register_dummy_model"
prompts
=
[{
"prompt"
:
"What's in the image?<image>"
,
...
...
tests/models/test_registry.py
View file @
8c6de96e
...
...
@@ -3,7 +3,14 @@ import warnings
import
pytest
import
torch.cuda
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
(
is_embedding_model
,
is_text_generation_model
,
supports_multimodal
)
from
vllm.model_executor.models.registry
import
(
_EMBEDDING_MODELS
,
_MULTIMODAL_MODELS
,
_SPECULATIVE_DECODING_MODELS
,
_TEXT_GENERATION_MODELS
,
ModelRegistry
)
from
vllm.platforms
import
current_platform
from
..utils
import
fork_new_process_for_each_test
...
...
@@ -12,7 +19,20 @@ from ..utils import fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
ModelRegistry
.
get_supported_archs
())
def
test_registry_imports
(
model_arch
):
# Ensure all model classes can be imported successfully
ModelRegistry
.
resolve_model_cls
(
model_arch
)
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
if
model_arch
in
_SPECULATIVE_DECODING_MODELS
:
pass
# Ignore these models which do not have a unified format
else
:
assert
is_text_generation_model
(
model_cls
)
is
(
model_arch
in
_TEXT_GENERATION_MODELS
or
model_arch
in
_MULTIMODAL_MODELS
)
assert
is_embedding_model
(
model_cls
)
is
(
model_arch
in
_EMBEDDING_MODELS
)
assert
supports_multimodal
(
model_cls
)
is
(
model_arch
in
_MULTIMODAL_MODELS
)
@
fork_new_process_for_each_test
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
View file @
8c6de96e
...
...
@@ -9,6 +9,12 @@ def register():
ModelRegistry
.
register_model
(
"MyOPTForCausalLM"
,
MyOPTForCausalLM
)
# Test passing lazy model
if
"MyGemma2Embedding"
not
in
ModelRegistry
.
get_supported_archs
():
ModelRegistry
.
register_model
(
"MyGemma2Embedding"
,
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding"
,
)
if
"MyLlava"
not
in
ModelRegistry
.
get_supported_archs
():
ModelRegistry
.
register_model
(
"MyLlava"
,
"vllm_add_dummy_model.my_llava:MyLlava"
)
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
0 → 100644
View file @
8c6de96e
from
typing
import
List
,
Optional
,
Union
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.models.gemma2_embedding
import
Gemma2EmbeddingModel
from
vllm.sequence
import
IntermediateTensors
class
MyGemma2Embedding
(
Gemma2EmbeddingModel
):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
if
isinstance
(
hidden_states
,
IntermediateTensors
):
return
hidden_states
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
vllm/model_executor/models/__init__.py
View file @
8c6de96e
from
.interfaces
import
(
HasInnerState
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
has_inner_state
,
supports_lora
,
supports_multimodal
,
supports_pp
)
from
.interfaces_base
import
(
VllmModelForEmbedding
,
VllmModelForTextGeneration
,
is_embedding_model
,
is_text_generation_model
)
from
.registry
import
ModelRegistry
__all__
=
[
"ModelRegistry"
,
"VllmModelForEmbedding"
,
"is_embedding_model"
,
"VllmModelForTextGeneration"
,
"is_text_generation_model"
,
"HasInnerState"
,
"has_inner_state"
,
"SupportsLoRA"
,
...
...
vllm/model_executor/models/interfaces.py
View file @
8c6de96e
import
inspect
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
...
...
@@ -6,9 +5,9 @@ import torch
from
typing_extensions
import
TypeIs
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -142,9 +141,7 @@ def supports_lora(
return
result
def
_supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsLoRA
]],
TypeIs
[
SupportsLoRA
]]:
def
_supports_lora
(
model
:
Union
[
Type
[
object
],
object
])
->
bool
:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsLoRAType
)
...
...
@@ -175,10 +172,7 @@ class SupportsPP(Protocol):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
"AttentionMetadata"
,
*
,
intermediate_tensors
:
Optional
[
"IntermediateTensors"
],
)
->
Union
[
torch
.
Tensor
,
"IntermediateTensors"
]:
"""
...
...
@@ -205,10 +199,7 @@ class _SupportsPPType(Protocol):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
"AttentionMetadata"
,
*
,
intermediate_tensors
:
Optional
[
"IntermediateTensors"
],
)
->
Union
[
torch
.
Tensor
,
"IntermediateTensors"
]:
...
...
...
@@ -257,24 +248,19 @@ def supports_pp(
return
supports_attributes
and
supports_inspect
def
_supports_pp_attributes
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
bool
,
TypeIs
[
Type
[
SupportsPP
]],
TypeIs
[
SupportsPP
]]:
def
_supports_pp_attributes
(
model
:
Union
[
Type
[
object
],
object
])
->
bool
:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsPPType
)
return
isinstance
(
model
,
SupportsPP
)
def
_supports_pp_inspect
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
bool
,
TypeIs
[
Type
[
SupportsPP
]],
TypeIs
[
SupportsPP
]]:
def
_supports_pp_inspect
(
model
:
Union
[
Type
[
object
],
object
])
->
bool
:
model_forward
=
getattr
(
model
,
"forward"
,
None
)
if
not
callable
(
model_forward
):
return
False
forward_params
=
inspect
.
signature
(
model_forward
).
parameters
return
"intermediate_tensors"
in
forward_params
return
supports_kw
(
model_forward
,
"intermediate_tensors"
)
@
runtime_checkable
...
...
vllm/model_executor/models/interfaces_base.py
0 → 100644
View file @
8c6de96e
from
typing
import
(
TYPE_CHECKING
,
List
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
typing_extensions
import
TypeIs
,
TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.pooler
import
PoolerOutput
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
# The type of HF config
C_co
=
TypeVar
(
"C_co"
,
bound
=
PretrainedConfig
,
covariant
=
True
)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
T
=
TypeVar
(
"T"
,
default
=
torch
.
Tensor
)
T_co
=
TypeVar
(
"T_co"
,
default
=
torch
.
Tensor
,
covariant
=
True
)
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes
@
runtime_checkable
class
VllmModel
(
Protocol
[
C_co
,
T_co
]):
def
__init__
(
self
,
config
:
C_co
,
*
,
cache_config
:
Optional
[
"CacheConfig"
],
quant_config
:
Optional
[
"QuantizationConfig"
],
)
->
None
:
...
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
"AttentionMetadata"
,
)
->
T_co
:
...
def
_check_vllm_model_init
(
model
:
Union
[
Type
[
object
],
object
])
->
bool
:
model_init
=
model
.
__init__
vllm_kws
=
(
"cache_config"
,
"quant_config"
)
missing_kws
=
tuple
(
kw
for
kw
in
vllm_kws
if
not
supports_kw
(
model_init
,
kw
))
if
missing_kws
and
(
isinstance
(
model
,
type
)
and
issubclass
(
model
,
nn
.
Module
)):
logger
.
warning
(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s"
,
model
,
missing_kws
,
)
return
len
(
missing_kws
)
==
0
def
_check_vllm_model_forward
(
model
:
Union
[
Type
[
object
],
object
])
->
bool
:
model_forward
=
getattr
(
model
,
"forward"
,
None
)
if
not
callable
(
model_forward
):
return
False
vllm_kws
=
(
"input_ids"
,
"positions"
,
"kv_caches"
,
"attn_metadata"
)
missing_kws
=
tuple
(
kw
for
kw
in
vllm_kws
if
not
supports_kw
(
model_forward
,
kw
))
if
missing_kws
and
(
isinstance
(
model
,
type
)
and
issubclass
(
model
,
nn
.
Module
)):
logger
.
warning
(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s"
,
model
,
missing_kws
,
)
return
len
(
missing_kws
)
==
0
@
overload
def
is_vllm_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModel
]]:
...
@
overload
def
is_vllm_model
(
model
:
object
)
->
TypeIs
[
VllmModel
]:
...
def
is_vllm_model
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
VllmModel
]],
TypeIs
[
VllmModel
]]:
return
_check_vllm_model_init
(
model
)
and
_check_vllm_model_forward
(
model
)
@
runtime_checkable
class
VllmModelForTextGeneration
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
def
compute_logits
(
self
,
hidden_states
:
T
,
sampling_metadata
:
"SamplingMetadata"
,
)
->
Optional
[
T
]:
"""Return `None` if TP rank > 0."""
...
def
sample
(
self
,
logits
:
T
,
sampling_metadata
:
"SamplingMetadata"
,
)
->
"SamplerOutput"
:
"""Only called on TP rank 0."""
...
@
overload
def
is_text_generation_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForTextGeneration
]]:
...
@
overload
def
is_text_generation_model
(
model
:
object
)
->
TypeIs
[
VllmModelForTextGeneration
]:
...
def
is_text_generation_model
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
VllmModelForTextGeneration
]],
TypeIs
[
VllmModelForTextGeneration
]]:
if
not
is_vllm_model
(
model
):
return
False
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
VllmModelForTextGeneration
)
return
isinstance
(
model
,
VllmModelForTextGeneration
)
@
runtime_checkable
class
VllmModelForEmbedding
(
VllmModel
[
C_co
,
T
],
Protocol
[
C_co
,
T
]):
def
pooler
(
self
,
hidden_states
:
T
,
pooling_metadata
:
"PoolingMetadata"
,
)
->
"PoolerOutput"
:
"""Only called on TP rank 0."""
...
@
overload
def
is_embedding_model
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
VllmModelForEmbedding
]]:
...
@
overload
def
is_embedding_model
(
model
:
object
)
->
TypeIs
[
VllmModelForEmbedding
]:
...
def
is_embedding_model
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
VllmModelForEmbedding
]],
TypeIs
[
VllmModelForEmbedding
]]:
if
not
is_vllm_model
(
model
):
return
False
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
VllmModelForEmbedding
)
return
isinstance
(
model
,
VllmModelForEmbedding
)
vllm/model_executor/models/registry.py
View file @
8c6de96e
...
...
@@ -12,10 +12,12 @@ from vllm.logger import init_logger
from
vllm.utils
import
is_hip
from
.interfaces
import
supports_multimodal
,
supports_pp
from
.interfaces_base
import
is_embedding_model
,
is_text_generation_model
logger
=
init_logger
(
__name__
)
_GENERATION_MODELS
=
{
_TEXT_GENERATION_MODELS
=
{
# [Decoder-only]
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"ArcticForCausalLM"
:
(
"arctic"
,
"ArcticForCausalLM"
),
...
...
@@ -74,10 +76,9 @@ _GENERATION_MODELS = {
"Starcoder2ForCausalLM"
:
(
"starcoder2"
,
"Starcoder2ForCausalLM"
),
"SolarForCausalLM"
:
(
"solar"
,
"SolarForCausalLM"
),
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
# NOTE: The below models are for speculative decoding only
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
}
_EMBEDDING_MODELS
=
{
...
...
@@ -114,16 +115,18 @@ _MULTIMODAL_MODELS = {
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
}
_CONDITIONAL_GENERATION_MODELS
=
{
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
_SPECULATIVE_DECODING_MODELS
=
{
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
}
_MODELS
=
{
**
_GENERATION_MODELS
,
**
_
TEXT_
GENERATION_MODELS
,
**
_EMBEDDING_MODELS
,
**
_MULTIMODAL_MODELS
,
**
_
CONDITIONAL_GENERATION
_MODELS
,
**
_
SPECULATIVE_DECODING
_MODELS
,
}
# Architecture -> type or (module, class).
...
...
@@ -317,6 +320,19 @@ class ModelRegistry:
return
result
.
returncode
==
0
@
staticmethod
def
is_text_generation_model
(
architectures
:
Union
[
str
,
List
[
str
]])
->
bool
:
if
isinstance
(
architectures
,
str
):
architectures
=
[
architectures
]
if
not
architectures
:
logger
.
warning
(
"No model architectures are specified"
)
is_txt_gen
=
partial
(
ModelRegistry
.
_check_stateless
,
is_text_generation_model
,
default
=
False
)
return
any
(
is_txt_gen
(
arch
)
for
arch
in
architectures
)
@
staticmethod
def
is_embedding_model
(
architectures
:
Union
[
str
,
List
[
str
]])
->
bool
:
if
isinstance
(
architectures
,
str
):
...
...
@@ -324,7 +340,11 @@ class ModelRegistry:
if
not
architectures
:
logger
.
warning
(
"No model architectures are specified"
)
return
any
(
arch
in
_EMBEDDING_MODELS
for
arch
in
architectures
)
is_emb
=
partial
(
ModelRegistry
.
_check_stateless
,
is_embedding_model
,
default
=
False
)
return
any
(
is_emb
(
arch
)
for
arch
in
architectures
)
@
staticmethod
def
is_multimodal_model
(
architectures
:
Union
[
str
,
List
[
str
]])
->
bool
:
...
...
vllm/utils.py
View file @
8c6de96e
...
...
@@ -1277,6 +1277,15 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return
await
task
(
*
args
,
**
kwargs
)
def
supports_kw
(
callable
:
Callable
[...,
object
],
kw_name
:
str
)
->
bool
:
params
=
inspect
.
signature
(
callable
).
parameters
if
kw_name
in
params
:
return
True
return
any
(
param
.
kind
==
inspect
.
Parameter
.
VAR_KEYWORD
for
param
in
params
.
values
())
def
get_allowed_kwarg_only_overrides
(
callable
:
Callable
[...,
object
],
overrides
:
Optional
[
Dict
[
str
,
Any
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment