Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
d7afab6d
Unverified
Commit
d7afab6d
authored
Feb 14, 2024
by
Woosuk Kwon
Committed by
GitHub
Feb 14, 2024
Browse files
[BugFix] Fix GC bug for `LLM` class (#2882)
parent
31348dff
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
182 additions
and
170 deletions
+182
-170
tests/test_regression.py
tests/test_regression.py
+18
-0
vllm/lora/punica.py
vllm/lora/punica.py
+164
-170
No files found.
tests/test_regression.py
View file @
d7afab6d
...
@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
...
@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
will never happen again.
will never happen again.
"""
"""
import
gc
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
...
@@ -35,6 +39,20 @@ def test_max_tokens_none():
...
@@ -35,6 +39,20 @@ def test_max_tokens_none():
assert
len
(
prompts
)
==
len
(
outputs
)
assert
len
(
prompts
)
==
len
(
outputs
)
def
test_gc
():
llm
=
LLM
(
"facebook/opt-125m"
,
enforce_eager
=
True
)
del
llm
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
# The memory allocated for model and KV cache should be released.
# The memory allocated for PyTorch and others should be less than 50MB.
# Usually, it's around 10MB.
allocated
=
torch
.
cuda
.
memory_allocated
()
assert
allocated
<
50
*
1024
*
1024
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
pytest
import
pytest
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/lora/punica.py
View file @
d7afab6d
...
@@ -4,23 +4,26 @@ from typing import Optional
...
@@ -4,23 +4,26 @@ from typing import Optional
import
torch
import
torch
import_exc
=
None
try
:
def
_raise_import_error
(
e
):
import
vllm._punica_C
as
punica_kernels
if
torch
.
cuda
.
get_device_capability
()
<
(
8
,
0
):
except
ImportError
as
e
:
raise
ImportError
(
import_exc
=
e
"punica LoRA kernels require compute capability >= 8.0"
)
from
e
else
:
raise
ImportError
(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set."
)
from
e
if
import_exc
is
None
:
def
bgmv
(
def
bgmv
(
y
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
layer_idx
:
int
,
scale
:
float
,
scale
:
float
,
):
):
"""
"""
Semantics:
Semantics:
y[i] += (
y[i] += (
...
@@ -38,9 +41,15 @@ if import_exc is None:
...
@@ -38,9 +41,15 @@ if import_exc is None:
layer_idx: Layer index of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
scale: Scaling factor.
"""
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
punica_kernels
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
punica_kernels
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def
add_lora
(
y
:
torch
.
Tensor
,
def
add_lora
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
...
@@ -70,6 +79,11 @@ if import_exc is None:
...
@@ -70,6 +79,11 @@ if import_exc is None:
scale: Scaling factor.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
r
=
wb_t_all
.
size
(
-
1
)
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
if
buffer
is
None
:
# We set the buffer to be float32 by default to avoid
# We set the buffer to be float32 by default to avoid
...
@@ -78,12 +92,12 @@ if import_exc is None:
...
@@ -78,12 +92,12 @@ if import_exc is None:
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
device
=
x
.
device
)
punica_kernels
.
dispatch_bgmv
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
punica_kernels
.
dispatch_bgmv
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
1.0
)
1.0
)
punica_kernels
.
dispatch_bgmv
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
punica_kernels
.
dispatch_bgmv
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
scale
)
scale
)
def
add_lora_slice
(
y
:
torch
.
Tensor
,
def
add_lora_slice
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
...
@@ -119,6 +133,11 @@ if import_exc is None:
...
@@ -119,6 +133,11 @@ if import_exc is None:
y_offset: Offset to apply to the starting column of y.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
y_slice_size: Size of the y column slice.
"""
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
r
=
wb_t_all
.
size
(
-
1
)
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
if
buffer
is
None
:
# We set the buffer to be float32 by default to avoid
# We set the buffer to be float32 by default to avoid
...
@@ -149,28 +168,3 @@ if import_exc is None:
...
@@ -149,28 +168,3 @@ if import_exc is None:
y_slice_size
,
y_slice_size
,
y_offset
,
y_offset
,
)
)
else
:
def
_raise_exc
(
*
args
,
# pylint: disable=unused-argument
**
kwargs
# pylint: disable=unused-argument
):
if
torch
.
cuda
.
get_device_capability
()
<
(
8
,
0
):
raise
ImportError
(
"punica LoRA kernels require compute "
"capability>=8.0"
)
from
import_exc
else
:
raise
ImportError
(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set."
)
from
import_exc
bgmv
=
_raise_exc
add_lora
=
_raise_exc
add_lora_slice
=
_raise_exc
__all__
=
[
"bgmv"
,
"add_lora"
,
"add_lora_slice"
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment