Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
460c1884
"vscode:/vscode.git/clone" did not exist on "d094512296ad18968efffd925c372533e9dd12e3"
Unverified
Commit
460c1884
authored
Jul 31, 2024
by
Michael Goin
Committed by
GitHub
Jul 31, 2024
Browse files
[Bugfix] Support cpu offloading with fp8 quantization (#6960)
parent
bd700134
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
116 additions
and
33 deletions
+116
-33
tests/basic_correctness/test_cpu_offload.py
tests/basic_correctness/test_cpu_offload.py
+37
-6
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+53
-3
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+26
-24
No files found.
tests/basic_correctness/test_cpu_offload.py
View file @
460c1884
from
vllm.utils
import
is_hip
import
pytest
from
tests.quantization.utils
import
is_quant_method_supported
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
...
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
...
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
def
test_cpu_offload
():
def
test_cpu_offload
():
compare_two_settings
(
"meta-llama/Llama-2-7b-hf"
,
[],
compare_two_settings
(
"meta-llama/Llama-2-7b-hf"
,
[],
[
"--cpu-offload-gb"
,
"4"
])
[
"--cpu-offload-gb"
,
"4"
])
if
not
is_hip
():
# compressed-tensors quantization is currently not supported in ROCm.
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"fp8 is not supported on this GPU type."
)
def
test_cpu_offload_fp8
():
# Test quantization of an unquantized checkpoint
compare_two_settings
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
[
"--quantization"
,
"fp8"
],
[
"--quantization"
,
"fp8"
,
"--cpu-offload-gb"
,
"2"
])
# Test loading a quantized checkpoint
compare_two_settings
(
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
,
[],
[
"--cpu-offload-gb"
,
"2"
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"awq"
),
reason
=
"awq is not supported on this GPU type."
)
def
test_cpu_offload_awq
():
compare_two_settings
(
"casperhansen/llama-3-8b-instruct-awq"
,
[],
[
"--cpu-offload-gb"
,
"2"
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"gptq_marlin is not supported on this GPU type."
)
def
test_cpu_offload_compressed_tensors
():
# Test wNa16
compare_two_settings
(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
# Test w4a16_marlin24
compare_two_settings
(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
# Test w8a8
compare_two_settings
(
compare_two_settings
(
"nm-testing/llama
7b
-one
-
shot-
2_4-w4a16-marlin24-t
"
,
[],
"nm-testing/
tiny
llama-oneshot-
w8w8-test-static-shape-change
"
,
[],
[
"--cpu-offload-gb"
,
"1"
])
[
"--cpu-offload-gb"
,
"1"
])
vllm/model_executor/model_loader/loader.py
View file @
460c1884
...
@@ -7,6 +7,7 @@ import json
...
@@ -7,6 +7,7 @@ import json
import
math
import
math
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
huggingface_hub
import
huggingface_hub
...
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
...
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
supports_vision
)
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_tpu
from
vllm.utils
import
is_pin_memory_available
,
is_tpu
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU-resident parameters to ``target_device``.

    Intended for post-load weight processing (repacking, quantizing, ...)
    that expects parameters to live on the target device even though some
    of them are kept on the CPU, e.g. when CPU offloading is in use.

    On entry, every parameter of ``module`` whose device type is ``"cpu"``
    is moved to ``target_device``. On exit (including when the body raises)
    each parameter that still carries one of those names is copied back into
    a freshly allocated CPU tensor, pinned when ``is_pin_memory_available()``
    reports that pinning is possible. Parameters that were not on the CPU on
    entry, and parameters created inside the ``with`` body under new names,
    are left untouched.

    Args:
        module: Module whose parameters may need to be moved.
        target_device: Device the parameters must be on inside the context.

    Yields:
        The same ``module`` object that was passed in.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Only CPU-resident parameters are ever moved, so remembering their
    # names is enough: the device to restore to is always the CPU.
    cpu_param_names = {
        name
        for name, p in module.named_parameters() if p.device.type == "cpu"
    }

    if not cpu_param_names:
        # Nothing is offloaded: there is nothing to move now and nothing
        # to restore afterwards, so skip the restore pass entirely.
        yield module
        return

    for name, p in module.named_parameters():
        if name in cpu_param_names:
            p.data = p.data.to(target_device)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        # Restore parameters to the CPU, ignoring new parameters.
        # Lookup is by name (not by object) because the body may have
        # replaced a parameter's tensor while keeping its name.
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name not in cpu_param_names:
                # New parameters or parameters that were already on the
                # target device are untouched
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(size=p.data.size(),
                                           stride=p.data.stride(),
                                           dtype=p.data.dtype,
                                           layout=p.data.layout,
                                           device="cpu",
                                           pin_memory=pin_memory)
            cpu_data.copy_(p.data)
            p.data = cpu_data
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -275,8 +318,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -275,8 +318,9 @@ class DefaultModelLoader(BaseModelLoader):
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
cache_config
:
CacheConfig
)
->
nn
.
Module
:
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
t
orch
.
device
(
device_config
.
device
)
:
with
t
arget_
device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
,
scheduler_config
)
cache_config
,
scheduler_config
)
...
@@ -291,6 +335,12 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -291,6 +335,12 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
quant_method
.
process_weights_after_loading
(
module
)
return
model
.
eval
()
return
model
.
eval
()
...
...
vllm/model_executor/models/utils.py
View file @
460c1884
...
@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# offload parameters to CPU
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters
=
False
for
p
in
module
.
parameters
():
for
p
in
module
.
parameters
():
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
# we use per-parameter offloading
# we use per-parameter offloading
...
@@ -94,7 +95,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -94,7 +95,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
break
break
# `torch.empty_like` does not support `pin_memory` argument
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty
(
size
=
p
.
data
.
size
(),
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
layout
=
p
.
data
.
layout
,
device
=
'cpu'
,
device
=
'cpu'
,
...
@@ -102,9 +104,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -102,9 +104,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
cpu_data
.
copy_
(
p
.
data
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
p
.
data
=
cpu_data
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
offloaded_parameters
=
True
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
module
.
state_dict
()
if
offloaded_parameters
:
original_forward
=
module
.
forward
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
def
forward
(
*
args
,
**
kwargs
):
...
@@ -113,7 +115,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -113,7 +115,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# here we blindly call `to(device)`
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
state_dict
.
items
()
for
k
,
v
in
module
.
state_dict
()
.
items
()
}
}
output
=
functional_call
(
module
,
output
=
functional_call
(
module
,
device_state
,
device_state
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment