Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a7d59688
Unverified
Commit
a7d59688
authored
Jan 13, 2025
by
Shanshan Shen
Committed by
GitHub
Jan 13, 2025
Browse files
[Platform] Move get_punica_wrapper() function to Platform (#11516)
Signed-off-by:
Shanshan Shen
<
467638484@qq.com
>
parent
458e63a2
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
32 additions
and
17 deletions
+32
-17
vllm/lora/punica_wrapper/punica_selector.py
vllm/lora/punica_wrapper/punica_selector.py
+9
-17
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+4
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-0
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+4
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+7
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-0
No files found.
vllm/lora/punica_wrapper/punica_selector.py
View file @
a7d59688
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
resolve_obj_by_qualname
from
.punica_base
import
PunicaWrapperBase
...
...
@@ -7,20 +8,11 @@ logger = init_logger(__name__)
def
get_punica_wrapper
(
*
args
,
**
kwargs
)
->
PunicaWrapperBase
:
if
current_platform
.
is_cuda_alike
():
# Lazy import to avoid ImportError
from
vllm.lora.punica_wrapper.punica_gpu
import
PunicaWrapperGPU
logger
.
info_once
(
"Using PunicaWrapperGPU."
)
return
PunicaWrapperGPU
(
*
args
,
**
kwargs
)
elif
current_platform
.
is_cpu
():
# Lazy import to avoid ImportError
from
vllm.lora.punica_wrapper.punica_cpu
import
PunicaWrapperCPU
logger
.
info_once
(
"Using PunicaWrapperCPU."
)
return
PunicaWrapperCPU
(
*
args
,
**
kwargs
)
elif
current_platform
.
is_hpu
():
# Lazy import to avoid ImportError
from
vllm.lora.punica_wrapper.punica_hpu
import
PunicaWrapperHPU
logger
.
info_once
(
"Using PunicaWrapperHPU."
)
return
PunicaWrapperHPU
(
*
args
,
**
kwargs
)
else
:
raise
NotImplementedError
punica_wrapper_qualname
=
current_platform
.
get_punica_wrapper
()
punica_wrapper_cls
=
resolve_obj_by_qualname
(
punica_wrapper_qualname
)
punica_wrapper
=
punica_wrapper_cls
(
*
args
,
**
kwargs
)
assert
punica_wrapper
is
not
None
,
\
"the punica_wrapper_qualname("
+
punica_wrapper_qualname
+
") is wrong."
logger
.
info_once
(
"Using "
+
punica_wrapper_qualname
.
rsplit
(
"."
,
1
)[
1
]
+
"."
)
return
punica_wrapper
vllm/platforms/cpu.py
View file @
a7d59688
...
...
@@ -109,3 +109,7 @@ class CpuPlatform(Platform):
def
is_pin_memory_available
(
cls
)
->
bool
:
logger
.
warning
(
"Pin memory is not supported on CPU."
)
return
False
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
vllm/platforms/cuda.py
View file @
a7d59688
...
...
@@ -218,6 +218,10 @@ class CudaPlatformBase(Platform):
logger
.
info
(
"Using Flash Attention backend."
)
return
"vllm.attention.backends.flash_attn.FlashAttentionBackend"
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
...
...
vllm/platforms/hpu.py
View file @
a7d59688
...
...
@@ -63,3 +63,7 @@ class HpuPlatform(Platform):
def
is_pin_memory_available
(
cls
):
logger
.
warning
(
"Pin memory is not supported on HPU."
)
return
False
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
vllm/platforms/interface.py
View file @
a7d59688
...
...
@@ -276,6 +276,13 @@ class Platform:
return
False
return
True
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
"""
Return the punica wrapper for current platform.
"""
raise
NotImplementedError
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/rocm.py
View file @
a7d59688
...
...
@@ -153,3 +153,7 @@ class RocmPlatform(Platform):
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
envs
.
VLLM_USE_TRITON_AWQ
=
True
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment