Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3cb57698
Unverified
Commit
3cb57698
authored
Dec 15, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 14, 2024
Browse files
[Misc] Minor improvements to the readability of PunicaWrapperBase (#11200)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
ea7bd68d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
25 deletions
+27
-25
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+8
-6
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+17
-17
vllm/lora/punica_wrapper/punica_hpu.py
vllm/lora/punica_wrapper/punica_hpu.py
+2
-2
No files found.
vllm/lora/punica_wrapper/punica_base.py
View file @
3cb57698
...
@@ -63,7 +63,7 @@ class PunicaWrapperABC(ABC):
...
@@ -63,7 +63,7 @@ class PunicaWrapperABC(ABC):
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
offset_start
:
int
=
0
,
add_input
=
True
,
add_input
s
=
True
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -77,7 +77,7 @@ class PunicaWrapperABC(ABC):
...
@@ -77,7 +77,7 @@ class PunicaWrapperABC(ABC):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
add_input
s
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -367,12 +367,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -367,12 +367,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
offset_start
:
int
=
0
,
add_input
=
True
,
add_input
s
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
"""
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
Semantics:
offset = offset_start
for i in range(len(lora_b_stacked)):
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
...
@@ -386,7 +387,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -386,7 +387,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
"""
"""
# TODO: implement it based on torch ops
# TODO: implement it based on torch ops
...
@@ -397,7 +399,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -397,7 +399,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
add_input
s
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
"""
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
...
@@ -409,7 +411,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -409,7 +411,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
y (torch.Tensor): Output tensor.
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
add_input
s
(bool): Default to True.
"""
"""
# TODO: implement it based on torch ops
# TODO: implement it based on torch ops
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
3cb57698
...
@@ -67,7 +67,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -67,7 +67,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
add_input
s
:
bool
,
):
):
#No LoRA request, so return directly
#No LoRA request, so return directly
if
self
.
no_lora
:
if
self
.
no_lora
:
...
@@ -77,7 +77,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -77,7 +77,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all
,
w_t_all
,
y
,
y
,
*
self
.
prefill_metadata
,
*
self
.
prefill_metadata
,
add_input
,
add_input
s
,
)
)
def
_expand_decode
(
def
_expand_decode
(
...
@@ -85,9 +85,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -85,9 +85,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
add_input
s
:
bool
,
):
):
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
)
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
s
)
def
_expand_slice_prefill
(
def
_expand_slice_prefill
(
self
,
self
,
...
@@ -96,7 +96,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -96,7 +96,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
add_input
s
:
bool
,
):
):
#No LoRA request, so return directly
#No LoRA request, so return directly
if
self
.
no_lora
:
if
self
.
no_lora
:
...
@@ -108,7 +108,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -108,7 +108,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
*
self
.
prefill_metadata
,
*
self
.
prefill_metadata
,
y_offset
,
y_offset
,
y_slice_size
,
y_slice_size
,
add_input
,
add_input
s
,
)
)
def
_expand_slice_decode
(
def
_expand_slice_decode
(
...
@@ -118,10 +118,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -118,10 +118,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
add_input
s
:
bool
,
):
):
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
y_slice_size
,
add_input
s
)
def
_apply_expand
(
def
_apply_expand
(
self
,
self
,
...
@@ -130,7 +130,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -130,7 +130,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
,
add_input
s
:
bool
=
True
,
):
):
"""
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
...
@@ -141,7 +141,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -141,7 +141,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_slice_fun
:
Callable
=
(
self
.
_expand_slice_prefill
expand_slice_fun
:
Callable
=
(
self
.
_expand_slice_prefill
if
self
.
is_prefill
else
if
self
.
is_prefill
else
self
.
_expand_slice_decode
)
self
.
_expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
s
)
def
_apply_shrink
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
def
_apply_shrink
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
):
w_t_all
:
torch
.
Tensor
,
scale
:
float
):
...
@@ -194,7 +194,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -194,7 +194,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
offset_start
:
int
=
0
,
add_input
=
True
,
add_input
s
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
"""
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Performs GEMM and bias addition for multiple slices of lora_b.
...
@@ -213,7 +213,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -213,7 +213,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (Tuple[int, ...]): Every slice's size
add_input (bool): Defaults to True.
add_input
s
(bool): Defaults to True.
"""
"""
y_org
=
y
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
...
@@ -228,7 +228,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -228,7 +228,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_b_stacked
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
offset_left
,
offset_left
,
output_slices
[
slice_idx
],
output_slices
[
slice_idx
],
add_input
=
add_input
,
add_input
s
=
add_input
s
,
)
)
offset_left
+=
output_slices
[
slice_idx
]
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
y
=
y
.
view_as
(
y_org
)
...
@@ -237,7 +237,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -237,7 +237,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
add_input
s
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
"""
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
...
@@ -249,13 +249,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -249,13 +249,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y (torch.Tensor): Output tensor.
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_input (bool): Default to True.
add_input
s
(bool): Default to True.
"""
"""
# Embedding layer only need expand op
# Embedding layer only need expand op
expand_fun
:
Callable
=
(
self
.
_expand_prefill
expand_fun
:
Callable
=
(
self
.
_expand_prefill
if
self
.
is_prefill
else
self
.
_expand_decode
)
if
self
.
is_prefill
else
self
.
_expand_decode
)
expand_fun
(
y
,
x
,
lora_b_stacked
,
add_input
)
expand_fun
(
y
,
x
,
lora_b_stacked
,
add_input
s
)
def
add_lora_linear
(
self
,
def
add_lora_linear
(
self
,
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
...
@@ -311,7 +311,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...
@@ -311,7 +311,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_b_stacked
,
lora_b_stacked
,
None
,
None
,
output_slices
,
output_slices
,
add_input
=
True
,
add_input
s
=
True
,
**
kwargs
)
**
kwargs
)
def
add_lora_logits
(
self
,
def
add_lora_logits
(
self
,
...
...
vllm/lora/punica_wrapper/punica_hpu.py
View file @
3cb57698
...
@@ -21,7 +21,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
...
@@ -21,7 +21,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
add_input
s
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
dispatch_bgmv_embedding
(
y
,
x
,
lora_b_stacked
,
0
)
dispatch_bgmv_embedding
(
y
,
x
,
lora_b_stacked
,
0
)
...
@@ -81,7 +81,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
...
@@ -81,7 +81,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
lora_bias_stacked
:
Optional
[
Tuple
[
torch
.
Tensor
,
...]],
output_slices
:
Tuple
[
int
,
...],
output_slices
:
Tuple
[
int
,
...],
offset_start
:
int
=
0
,
offset_start
:
int
=
0
,
add_input
=
True
,
add_input
s
=
True
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment