norm / vllm / Commits

Commit d7afab6d (Unverified)
[BugFix] Fix GC bug for `LLM` class (#2882)

Authored Feb 14, 2024 by Woosuk Kwon, committed via GitHub on Feb 14, 2024
Parent: 31348dff

Showing 2 changed files with 182 additions and 170 deletions (+182, -170)
tests/test_regression.py   +18   -0
vllm/lora/punica.py        +164  -170
tests/test_regression.py

@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
 will never happen again.

 """
+import gc
+
+import torch
+
 from vllm import LLM, SamplingParams

@@ -35,6 +39,20 @@ def test_max_tokens_none():
     assert len(prompts) == len(outputs)


+def test_gc():
+    llm = LLM("facebook/opt-125m", enforce_eager=True)
+    del llm
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # The memory allocated for model and KV cache should be released.
+    # The memory allocated for PyTorch and others should be less than 50MB.
+    # Usually, it's around 10MB.
+    allocated = torch.cuda.memory_allocated()
+    assert allocated < 50 * 1024 * 1024
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
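The new `test_gc` encodes the behavior the fix is meant to guarantee: once the last reference to the `LLM` object is dropped, a garbage-collection pass should return the model weights and KV cache to the allocator. Below is a minimal sketch of the same check run interactively rather than under pytest; it assumes a CUDA GPU is available and that the `facebook/opt-125m` weights can be downloaded.

import gc

import torch
from vllm import LLM

# Build an engine, then drop the only reference to it.
llm = LLM("facebook/opt-125m", enforce_eager=True)
del llm

# Force a collection and release cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

# With the fix, only PyTorch bookkeeping should remain allocated on the
# device, well under the 50 MB bound that test_gc asserts.
print(f"allocated: {torch.cuda.memory_allocated() / 1024 / 1024:.1f} MiB")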
vllm/lora/punica.py

@@ -4,173 +4,167 @@ from typing import Optional

Removed: the module-level kernel import guarded by `import_exc`, the `if import_exc is None:` block that nested the function definitions, the `_raise_exc` fallback, and the `__all__` list. The bodies and docstrings of `bgmv`, `add_lora`, and `add_lora_slice` sat one indentation level deeper inside that guard but are otherwise identical to the new top-level definitions shown after this block, so they are elided here.

import torch

import_exc = None

try:
    import vllm._punica_C as punica_kernels
except ImportError as e:
    import_exc = e

if import_exc is None:

    # ... definitions of bgmv, add_lora, and add_lora_slice, nested one
    # level deeper but otherwise matching the new top-level versions below ...

else:

    def _raise_exc(
            *args,  # pylint: disable=unused-argument
            **kwargs  # pylint: disable=unused-argument
    ):
        if torch.cuda.get_device_capability() < (8, 0):
            raise ImportError("punica LoRA kernels require compute "
                              "capability>=8.0") from import_exc
        else:
            raise ImportError(
                "punica LoRA kernels could not be imported. If you built vLLM "
                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
                "was set.") from import_exc

    bgmv = _raise_exc
    add_lora = _raise_exc
    add_lora_slice = _raise_exc

__all__ = [
    "bgmv",
    "add_lora",
    "add_lora_slice",
]

Added: a `_raise_import_error` helper plus top-level definitions of `bgmv`, `add_lora`, and `add_lora_slice` that import the punica kernels lazily inside each call. After this commit the file section reads:

import torch


def _raise_import_error(e):
    if torch.cuda.get_device_capability() < (8, 0):
        raise ImportError(
            "punica LoRA kernels require compute capability >= 8.0") from e
    else:
        raise ImportError(
            "punica LoRA kernels could not be imported. If you built vLLM "
            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
            "was set.") from e


def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
        matrices.
      indicies: Shape: `[B]`. Indices of the weight matrices.
      layer_idx: Layer index of the weight matrices.
      scale: Scaling factor.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical innacuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                 scale)


def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """
    Same as `add_lora` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indicies,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    punica_kernels.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indicies,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
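The docstrings above state the kernel semantics in plain tensor algebra. As a sanity reference only (not part of this commit), here is a hedged pure-PyTorch sketch of what `add_lora` computes according to the documented shapes; the helper name `add_lora_reference` and the toy sizes are invented for illustration, and it loops per sample instead of dispatching the fused punica kernel.

import torch


def add_lora_reference(y, x, wa_t_all, wb_t_all, indicies, layer_idx, scale):
    # y[i] += (x[i] @ A_i^T @ B_i^T) * scale, where A_i / B_i are the LoRA
    # matrices selected by indicies[i] and layer_idx (see the add_lora
    # docstring for the shapes).
    for i in range(x.size(0)):
        wa = wa_t_all[indicies[i], layer_idx].transpose(-1, -2)  # [H1, R]
        wb = wb_t_all[indicies[i], layer_idx].transpose(-1, -2)  # [R, H2]
        y[i] += (x[i].unsqueeze(0) @ wa @ wb).squeeze(0) * scale


# Toy shapes: 2 requests, 3 layers, rank 4, H1=8 -> H2=16, 5 LoRA adapters.
B, L, R, H1, H2, N = 2, 3, 4, 8, 16, 5
y = torch.zeros(B, H2)
x = torch.randn(B, H1)
wa_t_all = torch.randn(N, L, R, H1)   # transposed LoRA A matrices
wb_t_all = torch.randn(N, L, H2, R)   # transposed LoRA B matrices
indicies = torch.tensor([0, 3])
add_lora_reference(y, x, wa_t_all, wb_t_all, indicies, layer_idx=1, scale=1.0)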