Unverified Commit 19fa90ed authored by Asaf Gardin's avatar Asaf Gardin Committed by GitHub
Browse files

[Quantization] - Layerwise reloading of Attention/KV quantized models (#38995)


Signed-off-by: Josephasafg <ajgard7@gmail.com>
parent 03f8d3a5
......@@ -164,6 +164,34 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert add_perp < mul_perp
def test_kv_scale_reload(vllm_runner):
    """Reload a checkpoint that ships k_scale/v_scale weights over dummy weights.

    The engine is started with ``load_format="dummy"``, switched to
    ``load_format="auto"``, and asked to reload the same model. The prompt
    perplexity measured afterwards must be below 10.
    """
    if not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    checkpoint = "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"
    prompt = "The capital of France is the city of Paris"
    masked_prefix = "The capital of France is"
    load_overrides = {"load_config": {"load_format": "auto"}}

    # Start from dummy weights, then reload the real checkpoint on top of them
    with vllm_runner(
        model_name=checkpoint,
        load_format="dummy",
        enable_prefix_caching=False,
        max_model_len=16,
        max_num_seqs=1,
    ) as llm:
        llm.collective_rpc("update_config", kwargs={"overrides": load_overrides})
        llm.collective_rpc("reload_weights", kwargs={"weights_path": checkpoint})

        perplexities = llm.generate_prompt_perplexity(
            [prompt],
            mask=[masked_prefix],
        )
        reloaded_perp = perplexities[0]
        assert reloaded_perp < 10
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
......
......@@ -8,10 +8,9 @@ which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
2. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same as the
3. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quantizations may cause
......
......@@ -200,6 +200,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
if hasattr(model, "_original_do_torchao_reload"):
model._do_torchao_reload = model._original_do_torchao_reload
deferred_attn: list[tuple[torch.nn.Module, LayerReloadingInfo]] = []
for layer in model.modules():
info = get_layerwise_info(layer)
if not info.can_load():
......@@ -208,22 +210,11 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
# Attention/MLA layers are processed after all other layers
if isinstance(layer, (Attention, MLAAttention)):
if info.load_numel > 0:
raise NotImplementedError(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
elif info.kernel_tensors is None:
raise NotImplementedError(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)
deferred_attn.append((layer, info))
continue
# No weights were loaded
elif info.load_numel <= 0:
if info.load_numel <= 0:
# first load: checkpoint did not contain weights for this layer
if info.kernel_tensors is None:
_layerwise_process(layer, info)
......@@ -244,11 +235,58 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
info.reset()
# Process attention layers after all other layers are done
for layer, info in deferred_attn:
_finalize_attention_layer(layer, info, model_config)
info.reset()
def finalize_layerwise_reload(*args, **kwargs):
    """Alias entry point that forwards every argument unchanged to
    `finalize_layerwise_processing`. Nothing is returned."""
    finalize_layerwise_processing(*args, **kwargs)
def _finalize_attention_layer(
    layer: torch.nn.Module, info: LayerReloadingInfo, model_config: ModelConfig
) -> None:
    """Finalize a single attention layer.

    The layer's kernel tensors are placed back on it, and then either:
      * the buffered scale weights are reloaded (when weights were loaded
        for this layer, i.e. ``info.load_numel > 0``), or
      * ``layer.process_weights_after_loading`` is run with the model dtype
        (when nothing was loaded for this layer).

    Raises:
        ValueError: if ``info.kernel_tensors`` is None, regardless of
            whether any weights were loaded.
    """
    received_weights = info.load_numel > 0

    # Without recorded kernel tensors there is nothing to restore onto the layer
    if info.kernel_tensors is None:
        raise ValueError(
            "Layerwise loading of attention layers is not supported. "
            "Attention must always process after linears."
        )

    _place_kernel_tensors(layer, info)

    if received_weights:
        # Reload with new scale weights from checkpoint
        _reload_attention_scales(layer, info)
    else:
        layer.process_weights_after_loading(model_config.dtype)
def _reload_attention_scales(layer: torch.nn.Module, info: LayerReloadingInfo) -> None:
    """Replay buffered attention scale weights (k_scale, v_scale, etc.) onto
    ``layer`` during a reload, then run the quant method's post-load step.

    Does nothing when the layer has no ``quant_method`` attribute (or it is
    None).

    The processed values are finally copied into the original kernel tensors
    via ``.data.copy_()`` so existing references to them stay valid; this
    assumes the dtype/shapes of attention tensors do not change during
    processing.
    """
    method = getattr(layer, "quant_method", None)
    if method is None:
        return

    # Fresh scale Parameters carry sentinel values, which lets
    # process_weights_after_loading detect scales the checkpoint did not
    # provide
    method.create_weights(layer)

    for weight_name, bound_call in info.loaded_weights:
        target = getattr(layer, weight_name)
        # Point the recorded call at the newly created parameter before
        # replaying it
        bound_call.arguments["param"] = target
        load_into = _get_weight_loader(target)
        load_into(*bound_call.args, **bound_call.kwargs)

    method.process_weights_after_loading(layer)
    _copy_and_restore_kernel_tensors(layer, info)
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
"""
Finalize layer loading after all weights have been buffered.
......@@ -278,7 +316,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
param.weight_loader(*args.args, **args.kwargs)
# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method = getattr(layer, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
quant_method.process_weights_after_loading(layer)
......@@ -286,13 +323,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
if info.kernel_tensors is not None:
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
_place_kernel_tensors(layer, info)
_copy_and_restore_kernel_tensors(layer, info)
info.reset()
logger.debug("%s: Processed", layer.__class__.__name__)
......@@ -311,6 +342,19 @@ def _get_weight_loader(tensor: torch.Tensor):
return getattr(tensor, "weight_loader", default_weight_loader)
def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Write the layer's current tensor values into the saved kernel tensors,
    then put those kernel tensors back on the layer.

    Copying in place (rather than rebinding) preserves cudagraph references
    to the original storage. ``info.kernel_tensors`` must not be None.
    """
    assert info.kernel_tensors is not None
    saved_params, saved_buffers = info.kernel_tensors

    # Parameters first, then buffers: each saved tensor receives the value
    # currently stored on the layer under the same name
    for saved_group in (saved_params, saved_buffers):
        for tensor_name, saved_tensor in saved_group.items():
            saved_tensor.data.copy_(getattr(layer, tensor_name))

    _place_kernel_tensors(layer, info)
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for name in get_layer_tensors(layer):
delattr(layer, name)
......
......@@ -1364,8 +1364,8 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> N
if param.numel() == 1 and loaded_weight.numel() == 1:
# Sometimes scalar values aren't considered tensors with shapes
# so if both param and loaded_weight are a scalar,
# "broadcast" instead of copy
param.data.fill_(loaded_weight.item())
# reshape to match before copying
param.data.copy_(loaded_weight.view(param.shape))
else:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment