"vscode:/vscode.git/clone" did not exist on "5313c2cb8b3bcf7f71c0e6024c59d120efe94d88"
Unverified Commit 0130223b authored by Vasiliy Kuznetsov's avatar Vasiliy Kuznetsov Committed by GitHub
Browse files

fix memory for online fp8 quantization with streaming weight load (#31914)


Signed-off-by: default avatarvasiliy <vasiliy@fb.com>
parent 5d1aef30
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
Run `pytest tests/quantization/test_fp8.py --forked`. Run `pytest tests/quantization/test_fp8.py --forked`.
""" """
import logging
import pytest import pytest
import regex as re
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
...@@ -195,6 +198,99 @@ def test_online_quantization( ...@@ -195,6 +198,99 @@ def test_online_quantization(
print(outputs[0][1]) print(outputs[0][1])
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
vllm_runner,
caplog_mp_spawn,
monkeypatch,
) -> None:
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
# 1. it covers both Linear and MoE paths
# 2. it is already used by other tests in CI, so adding it here
# does not increase disk space for CI runners
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
# 1.3 GiB fp8), but could not as adding one more model makes CI
# run out of disk space.
model_name = "allenai/OLMoE-1B-7B-0125-Instruct"
# Force spawn to ensure caplog_mp_spawn works consistently
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
with (
caplog_mp_spawn(logging.DEBUG) as log_holder,
vllm_runner(
model_name,
quantization="fp8",
enforce_eager=True,
) as llm,
):
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
print(outputs[0][1])
log_text = log_holder.text
# Parse memory usage from captured logs
model_memory_gib = None
peak_memory_gib = None
for line in log_text.splitlines():
if model_memory_gib is None:
match = re.search(r"Model loading took ([\d.]+) GiB memory", line)
if match:
model_memory_gib = float(match.group(1))
if peak_memory_gib is None:
match = re.search(
r"Peak GPU memory after loading weights: ([\d.]+) GiB", line
)
if match:
peak_memory_gib = float(match.group(1))
assert model_memory_gib is not None, "Could not find model loading memory log"
assert peak_memory_gib is not None, "Could not find peak memory log"
print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
expected_model_memory_gib = 6.7
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
# expected as when we load and quantize weights in a streaming fashion we
# need to have individual weights in bf16 + fp8 alive at the same time.
expected_peak_memory_gib = expected_model_memory_gib * 1.4
assert model_memory_gib < expected_model_memory_gib, (
f"{model_memory_gib=} higher than {expected_model_memory_gib}"
)
assert peak_memory_gib < expected_peak_memory_gib, (
f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
)
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
vllm_runner,
monkeypatch,
caplog,
) -> None:
with vllm_runner(
"ibm-granite/granite-3.0-1b-a400m-base",
quantization="fp8",
enforce_eager=True,
load_format="dummy",
) as llm:
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
print(outputs[0][1])
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("fp8"), not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.", reason="FP8 is not supported on this GPU type.",
......
...@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
...@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode): ...@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
return out return out
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
"""Copies any attrs present in `old` but not in `new` to `new`"""
new_attrs = set(dir(new))
attrs_to_set = {}
for attr in dir(old):
if attr not in new_attrs:
attrs_to_set[attr] = getattr(old, attr)
set_weight_attrs(new, attrs_to_set)
class Fp8LinearMethod(LinearMethodBase): class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8. """Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
...@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if not hasattr(layer, "_loaded_numel"): if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0 layer._loaded_numel = 0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight
# load the current weight chunk # load the current weight chunk
copy_numel_counter = CopyNumelCounter() copy_numel_counter = CopyNumelCounter()
with copy_numel_counter: with copy_numel_counter:
...@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if layer._loaded_numel == target_loaded_numel: if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer) self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing # Prevent the usual `process_weights_after_loading` call from doing
# anything # anything
layer._already_called_process_weights_after_loading = True layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return res return res
weight = ModelWeightParameter( weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
output_size_per_partition, output_size_per_partition,
input_size_per_partition, input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype, dtype=params_dtype,
), ),
input_dim=1, input_dim=1,
output_dim=0, output_dim=0,
weight_loader=patched_weight_loader, weight_loader=patched_weight_loader,
) )
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
# TODO(future): support block_quant in online quant path # TODO(future): support block_quant in online quant path
assert not self.block_quant assert not self.block_quant
...@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if not hasattr(layer, "_loaded_numel"): if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0 layer._loaded_numel = 0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight
# load the current weight chunk # load the current weight chunk
copy_numel_counter = CopyNumelCounter() copy_numel_counter = CopyNumelCounter()
with copy_numel_counter: with copy_numel_counter:
...@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if layer._loaded_numel == target_loaded_numel: if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer) self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call # Prevent the usual `process_weights_after_loading` call
# from doing anything # from doing anything
layer._already_called_process_weights_after_loading = True layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return res return res
new_extra_weight_attrs["weight_loader"] = patched_weight_loader new_extra_weight_attrs["weight_loader"] = patched_weight_loader
...@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
hidden_size, hidden_size,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
...@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts, num_experts,
hidden_size, hidden_size,
intermediate_size_per_partition, intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
...@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
......
...@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import ( ...@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import (
initialize_model, initialize_model,
process_weights_after_loading, process_weights_after_loading,
) )
from vllm.platforms import current_platform
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -56,6 +58,17 @@ class BaseModelLoader(ABC): ...@@ -56,6 +58,17 @@ class BaseModelLoader(ABC):
logger.debug("Loading weights on %s ...", load_device) logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it # Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config) self.load_weights(model, model_config)
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
peak_memory = torch.cuda.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
format_gib(peak_memory),
scope="local",
)
process_weights_after_loading(model, model_config, target_device) process_weights_after_loading(model, model_config, target_device)
return model.eval() return model.eval()
......
...@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader): ...@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader):
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model, model_config)
...@@ -1059,6 +1059,7 @@ def composed_weight_loader( ...@@ -1059,6 +1059,7 @@ def composed_weight_loader(
def initialize_dummy_weights( def initialize_dummy_weights(
model: torch.nn.Module, model: torch.nn.Module,
model_config: ModelConfig,
low: float = -1e-3, low: float = -1e-3,
high: float = 1e-3, high: float = 1e-3,
seed: int = 1234, seed: int = 1234,
...@@ -1075,7 +1076,27 @@ def initialize_dummy_weights( ...@@ -1075,7 +1076,27 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type. the parameter's number of elements and its data type.
""" """
# TODO(future PR): make the check below more generic as more online
# quant backends are added
is_fp8_py_quant = model_config.quantization == "fp8"
for param in model.state_dict().values(): for param in model.state_dict().values():
if is_fp8_py_quant and param.device == torch.device("meta"):
# for fp8.py's online quantization, dummy weight init will happen
# in `process_weights_after_loading`.
# TODO(future PR): consider refactoring dummy model init to compose
# better with online quantization
continue
initialize_single_dummy_weight(param, low, high, seed)
def initialize_single_dummy_weight(
param: torch.Tensor,
low: float = -1e-3,
high: float = 1e-3,
seed: int = 1234,
) -> None:
if torch.is_floating_point(param): if torch.is_floating_point(param):
if current_platform.is_tpu(): if current_platform.is_tpu():
generator = torch.Generator(device="cpu") generator = torch.Generator(device="cpu")
...@@ -1098,7 +1119,7 @@ def initialize_dummy_weights( ...@@ -1098,7 +1119,7 @@ def initialize_dummy_weights(
+ low + low
) )
torch._sync(param) torch._sync(param)
continue return
generator = torch.Generator(device=param.data.device) generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed) generator.manual_seed(seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment