Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0bd1ff43
Unverified
Commit
0bd1ff43
authored
Jan 09, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 09, 2025
Browse files
[Bugfix] Override dunder methods of placeholder modules (#11882)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
310aca88
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
220 additions
and
16 deletions
+220
-16
tests/test_utils.py
tests/test_utils.py
+44
-3
vllm/utils.py
vllm/utils.py
+176
-13
No files found.
tests/test_utils.py
View file @
0bd1ff43
...
@@ -7,9 +7,9 @@ import pytest
...
@@ -7,9 +7,9 @@ import pytest
import
torch
import
torch
from
vllm_test_utils
import
monitor
from
vllm_test_utils
import
monitor
from
vllm.utils
import
(
FlexibleArgumentParser
,
StoreBoolean
,
deprecate_kwargs
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
PlaceholderModule
,
get_open_port
,
memory_profiling
,
merge_async_iterators
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
supports_kw
)
memory_profiling
,
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
...
@@ -323,3 +323,44 @@ def test_memory_profiling():
...
@@ -323,3 +323,44 @@ def test_memory_profiling():
del
weights
del
weights
lib
.
cudaFree
(
handle1
)
lib
.
cudaFree
(
handle1
)
lib
.
cudaFree
(
handle2
)
lib
.
cudaFree
(
handle2
)
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
def
build_ctx
():
return
pytest
.
raises
(
ModuleNotFoundError
,
match
=
"No module named"
)
with
build_ctx
():
int
(
placeholder
)
with
build_ctx
():
placeholder
()
with
build_ctx
():
_
=
placeholder
.
some_attr
with
build_ctx
():
# Test conflict with internal __name attribute
_
=
placeholder
.
name
# OK to print the placeholder or use it in a f-string
_
=
repr
(
placeholder
)
_
=
str
(
placeholder
)
# No error yet; only error when it is used downstream
placeholder_attr
=
placeholder
.
placeholder_attr
(
"attr"
)
with
build_ctx
():
int
(
placeholder_attr
)
with
build_ctx
():
placeholder_attr
()
with
build_ctx
():
_
=
placeholder_attr
.
some_attr
with
build_ctx
():
# Test conflict with internal __module attribute
_
=
placeholder_attr
.
module
vllm/utils.py
View file @
0bd1ff43
...
@@ -46,7 +46,7 @@ import zmq
...
@@ -46,7 +46,7 @@ import zmq
import
zmq.asyncio
import
zmq.asyncio
from
packaging.version
import
Version
from
packaging.version
import
Version
from
torch.library
import
Library
from
torch.library
import
Library
from
typing_extensions
import
ParamSpec
,
TypeIs
,
assert_never
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
...
@@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
...
@@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
}
}
@
dataclass
(
frozen
=
True
)
class
_PlaceholderBase
:
class
PlaceholderModule
:
"""
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
:meth:`__getattr__` is not called when they are accessed.
See also:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""
def
__getattr__
(
self
,
key
:
str
)
->
Never
:
"""
The main class should implement this to throw an error
for attribute accesses representing downstream usage.
"""
raise
NotImplementedError
# [Basic customization]
def
__lt__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__lt__"
)
def
__le__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__le__"
)
def
__eq__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__eq__"
)
def
__ne__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__ne__"
)
def
__gt__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__gt__"
)
def
__ge__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__ge__"
)
def
__hash__
(
self
):
return
self
.
__getattr__
(
"__hash__"
)
def
__bool__
(
self
):
return
self
.
__getattr__
(
"__bool__"
)
# [Callable objects]
def
__call__
(
self
,
*
args
:
object
,
**
kwargs
:
object
):
return
self
.
__getattr__
(
"__call__"
)
# [Container types]
def
__len__
(
self
):
return
self
.
__getattr__
(
"__len__"
)
def
__getitem__
(
self
,
key
:
object
):
return
self
.
__getattr__
(
"__getitem__"
)
def
__setitem__
(
self
,
key
:
object
,
value
:
object
):
return
self
.
__getattr__
(
"__setitem__"
)
def
__delitem__
(
self
,
key
:
object
):
return
self
.
__getattr__
(
"__delitem__"
)
# __missing__ is optional according to __getitem__ specification,
# so it is skipped
# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.
# [Numeric Types]
def
__add__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__add__"
)
def
__sub__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__sub__"
)
def
__mul__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__mul__"
)
def
__matmul__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__matmul__"
)
def
__truediv__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__truediv__"
)
def
__floordiv__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__floordiv__"
)
def
__mod__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__mod__"
)
def
__divmod__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__divmod__"
)
def
__pow__
(
self
,
other
:
object
,
modulo
:
object
=
...):
return
self
.
__getattr__
(
"__pow__"
)
def
__lshift__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__lshift__"
)
def
__rshift__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__rshift__"
)
def
__and__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__and__"
)
def
__xor__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__xor__"
)
def
__or__
(
self
,
other
:
object
):
return
self
.
__getattr__
(
"__or__"
)
# r* and i* methods have lower priority than
# the methods for left operand so they are skipped
def
__neg__
(
self
):
return
self
.
__getattr__
(
"__neg__"
)
def
__pos__
(
self
):
return
self
.
__getattr__
(
"__pos__"
)
def
__abs__
(
self
):
return
self
.
__getattr__
(
"__abs__"
)
def
__invert__
(
self
):
return
self
.
__getattr__
(
"__invert__"
)
# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.
def
__index__
(
self
):
return
self
.
__getattr__
(
"__index__"
)
def
__round__
(
self
,
ndigits
:
object
=
...):
return
self
.
__getattr__
(
"__round__"
)
def
__trunc__
(
self
):
return
self
.
__getattr__
(
"__trunc__"
)
def
__floor__
(
self
):
return
self
.
__getattr__
(
"__floor__"
)
def
__ceil__
(
self
):
return
self
.
__getattr__
(
"__ceil__"
)
# [Context managers]
def
__enter__
(
self
):
return
self
.
__getattr__
(
"__enter__"
)
def
__exit__
(
self
,
*
args
:
object
,
**
kwargs
:
object
):
return
self
.
__getattr__
(
"__exit__"
)
class
PlaceholderModule
(
_PlaceholderBase
):
"""
"""
A placeholder object to use when a module does not exist.
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
This enables more informative errors when trying to access attributes
of a module that does not exists.
of a module that does not exists.
"""
"""
name
:
str
def
__init__
(
self
,
name
:
str
)
->
None
:
super
().
__init__
()
# Apply name mangling to avoid conflicting with module attributes
self
.
__name
=
name
def
placeholder_attr
(
self
,
attr_path
:
str
):
def
placeholder_attr
(
self
,
attr_path
:
str
):
return
_PlaceholderModuleAttr
(
self
,
attr_path
)
return
_PlaceholderModuleAttr
(
self
,
attr_path
)
def
__getattr__
(
self
,
key
:
str
):
def
__getattr__
(
self
,
key
:
str
):
name
=
self
.
name
name
=
self
.
__
name
try
:
try
:
importlib
.
import_module
(
self
.
name
)
importlib
.
import_module
(
name
)
except
ImportError
as
exc
:
except
ImportError
as
exc
:
for
extra
,
names
in
get_vllm_optional_dependencies
().
items
():
for
extra
,
names
in
get_vllm_optional_dependencies
().
items
():
if
name
in
names
:
if
name
in
names
:
...
@@ -1657,17 +1816,21 @@ class PlaceholderModule:
...
@@ -1657,17 +1816,21 @@ class PlaceholderModule:
"when the original module can be imported"
)
"when the original module can be imported"
)
@
dataclass
(
frozen
=
True
)
class
_PlaceholderModuleAttr
(
_PlaceholderBase
):
class
_PlaceholderModuleAttr
:
module
:
PlaceholderModule
def
__init__
(
self
,
module
:
PlaceholderModule
,
attr_path
:
str
)
->
None
:
attr_path
:
str
super
().
__init__
()
# Apply name mangling to avoid conflicting with module attributes
self
.
__module
=
module
self
.
__attr_path
=
attr_path
def
placeholder_attr
(
self
,
attr_path
:
str
):
def
placeholder_attr
(
self
,
attr_path
:
str
):
return
_PlaceholderModuleAttr
(
self
.
module
,
return
_PlaceholderModuleAttr
(
self
.
__
module
,
f
"
{
self
.
attr_path
}
.
{
attr_path
}
"
)
f
"
{
self
.
__
attr_path
}
.
{
attr_path
}
"
)
def
__getattr__
(
self
,
key
:
str
):
def
__getattr__
(
self
,
key
:
str
):
getattr
(
self
.
module
,
f
"
{
self
.
attr_path
}
.
{
key
}
"
)
getattr
(
self
.
__
module
,
f
"
{
self
.
__
attr_path
}
.
{
key
}
"
)
raise
AssertionError
(
"PlaceholderModule should not be used "
raise
AssertionError
(
"PlaceholderModule should not be used "
"when the original module can be imported"
)
"when the original module can be imported"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment