Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10760da8
Unverified
Commit
10760da8
authored
May 07, 2024
by
Austin Veselka
Committed by
GitHub
May 07, 2024
Browse files
[Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609)
parent
478aed58
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
32 deletions
+52
-32
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+30
-24
vllm/lora/layers.py
vllm/lora/layers.py
+22
-8
No files found.
vllm/lora/fully_sharded_layers.py
View file @
10760da8
# pylint: disable=unused-argument
# pylint: disable=unused-argument
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
return
lora_a
return
lora_a
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
,
out_orig_shape
=
output
.
view
(
-
1
,
...
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
)
def
_mcp_apply
_weights
(
x
,
bias
,
layer
):
def
_mcp_apply
(
x
,
bias
,
layer
):
"""
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
QKVParallelLinearWithShardedLora share the same
...
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
...
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
"""
"""
# expecting 2 for column parallel and 3 for qkv
# expecting 2 for column parallel and 3 for qkv
n
=
len
(
layer
.
lora_a_stacked
)
n
=
len
(
layer
.
lora_a_stacked
)
output
=
layer
.
base_layer
.
linear_method
.
apply_weights
(
output
=
layer
.
base_layer
.
quant_method
.
apply
(
layer
.
base_layer
,
x
,
bias
)
layer
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
...
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
...
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
Based on S-LoRA, slicing happens along the rank dim.
Based on S-LoRA, slicing happens along the rank dim.
"""
"""
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_a
[
0
]
is
None
or
lora_a
[
1
]
is
None
:
return
lora_a
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_start_idx
=
self
.
tp_rank
*
output_shard_size
output_start_idx
=
self
.
tp_rank
*
output_shard_size
lora_a
=
[
lora_a
=
[
lora_a
[
i
][:,
output_start_idx
:
output_start_idx
+
output_shard_size
]
lora_a
[
0
][:,
for
i
in
range
(
2
)
output_start_idx
:
output_start_idx
+
output_shard_size
],
lora_a
[
1
][:,
output_start_idx
:
output_start_idx
+
output_shard_size
]
]
]
return
lora_a
return
lora_a
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
_mcp_apply
_weights
(
x
,
bias
,
self
)
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
classmethod
@
_fully_sharded_can_replace
@
_fully_sharded_can_replace
...
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
...
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim.
Based on S-LoRA, slicing happens along the rank dim.
"""
"""
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_a
[
0
]
is
None
or
lora_a
[
1
]
is
None
or
lora_a
[
2
]
is
None
:
return
lora_a
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
lora_a
=
[
lora_a
=
[
lora_a
[
i
][:,
start_idx
[
i
]:
start_idx
[
i
]
+
lora_a
[
0
][:,
start_idx
[
0
]:
start_idx
[
0
]
+
shard_size
[
0
]],
shard_size
[
i
]]
if
lora_a
[
i
]
is
not
None
else
None
lora_a
[
1
][:,
start_idx
[
1
]:
start_idx
[
1
]
+
shard_size
[
1
]]
,
f
or
i
in
range
(
3
)
l
or
a_a
[
2
][:,
start_idx
[
2
]:
start_idx
[
2
]
+
shard_size
[
2
]]
]
]
return
lora_a
return
lora_a
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
_mcp_apply
_weights
(
x
,
bias
,
self
)
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
classmethod
@
_fully_sharded_can_replace
@
_fully_sharded_can_replace
...
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
...
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
return
lora_b
return
lora_b
def
apply_weights
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
self
.
base_layer
,
x
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
,
out_orig_shape
=
output
.
view
(
-
1
,
...
...
vllm/lora/layers.py
View file @
10760da8
# pylint: disable=unused-argument
# pylint: disable=unused-argument
import
math
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -145,11 +145,15 @@ class LoRAMapping:
...
@@ -145,11 +145,15 @@ class LoRAMapping:
class
BaseLayerWithLoRA
(
nn
.
Module
):
class
BaseLayerWithLoRA
(
nn
.
Module
):
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_lora_a
(
self
,
lora_a
:
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]
)
->
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]:
"""Slice lora a if splitting for tensor parallelism."""
"""Slice lora a if splitting for tensor parallelism."""
...
...
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_lora_b
(
self
,
lora_b
:
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]
)
->
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]:
"""Slice lora b if splitting with tensor parallelism."""
"""Slice lora b if splitting with tensor parallelism."""
...
...
...
@@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
return
lora_a
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_b
[
0
]
is
None
or
lora_b
[
1
]
is
None
:
return
lora_b
shard_size
=
self
.
output_dim
shard_size
=
self
.
output_dim
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
...
@@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
return
lora_a
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
lora_b_q
,
lora_b_k
,
lora_b_v
=
None
,
None
,
None
if
lora_b
[
0
]
is
not
None
:
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
...
@@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@
property
@
property
def
weight
(
self
):
def
weight
(
self
):
return
self
.
base_layer
.
weight
if
hasattr
(
return
self
.
base_layer
.
weight
if
hasattr
(
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment