"vscode:/vscode.git/clone" did not exist on "487678d046fe56560ff5dc6c91c3f3c31af7de6f"
Unverified Commit d28d86e8 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[QeRL] Fix online quantized reloading (#38442)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 995dea13
...@@ -812,7 +812,7 @@ steps: ...@@ -812,7 +812,7 @@ steps:
commands: commands:
- apt-get update && apt-get install -y curl libsodium23 - apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor - pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
...@@ -1242,7 +1242,7 @@ steps: ...@@ -1242,7 +1242,7 @@ steps:
- vllm/platforms/rocm.py - vllm/platforms/rocm.py
commands: commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
...@@ -2501,7 +2501,7 @@ steps: ...@@ -2501,7 +2501,7 @@ steps:
- tests/models/ - tests/models/
commands: commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
......
...@@ -13,5 +13,5 @@ steps: ...@@ -13,5 +13,5 @@ steps:
commands: commands:
- apt-get update && apt-get install -y curl libsodium23 - apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor - pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py - pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
...@@ -14,7 +14,7 @@ steps: ...@@ -14,7 +14,7 @@ steps:
- tests/models/ - tests/models/
commands: commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
# Avoid importing model tests that cause CUDA reinitialization error # Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
......
...@@ -38,7 +38,10 @@ def test_move_metatensors(): ...@@ -38,7 +38,10 @@ def test_move_metatensors():
def test_reload_lifecycle(): def test_reload_lifecycle():
layer = torch.nn.Linear(2, 3) layer = torch.nn.Linear(2, 3)
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer)) info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
restore_layer_on_meta(layer, info) restore_layer_on_meta(layer, info)
for name, tensor in get_layer_tensors(layer).items(): for name, tensor in get_layer_tensors(layer).items():
...@@ -48,7 +51,7 @@ def test_reload_lifecycle(): ...@@ -48,7 +51,7 @@ def test_reload_lifecycle():
assert tensor.__class__ == meta_tensor.__class__ assert tensor.__class__ == meta_tensor.__class__
assert tensor.__dict__ == meta_tensor.__dict__ assert tensor.__dict__ == meta_tensor.__dict__
materialize_layer(layer) materialize_layer(layer, info)
for name, tensor in get_layer_tensors(layer).items(): for name, tensor in get_layer_tensors(layer).items():
materialized_tensor = getattr(layer, name) materialized_tensor = getattr(layer, name)
assert tensor.dtype == materialized_tensor.dtype assert tensor.dtype == materialized_tensor.dtype
...@@ -60,7 +63,10 @@ def test_reload_lifecycle(): ...@@ -60,7 +63,10 @@ def test_reload_lifecycle():
def test_model_cleanup(dist_init, default_vllm_config): def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4) layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer assert layer.weight.weight_loader.__self__ is layer
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer)) info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = ( mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary() WeakKeyDictionary()
...@@ -90,39 +96,46 @@ def test_get_numel_loaded(): ...@@ -90,39 +96,46 @@ def test_get_numel_loaded():
assert ret == "value" assert ret == "value"
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"base_model,mul_model,add_model", "base_model,mul_model,add_model",
[ [
( pytest.param(
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply", "inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add", "inference-optimization/Qwen3-0.6B-debug-add",
marks=[pytest.mark.slow_test],
), ),
( pytest.param(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK", "inference-optimization/Qwen3-0.6B-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK", "inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK", "inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
marks=[pytest.mark.slow_test],
), ),
( pytest.param(
"inference-optimization/Qwen3-0.6B-W4A16-G128", "inference-optimization/Qwen3-0.6B-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128", "inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128", "inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
marks=[pytest.mark.slow_test],
), ),
( pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty", "inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply", "inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add", "inference-optimization/DeepSeek-V3-debug-add",
marks=[pytest.mark.slow_test],
), ),
( pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC", "inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC", "inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC", "inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
), ),
( pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16", "inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16", "inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16", "inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
marks=[pytest.mark.slow_test],
), ),
], ],
) )
...@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): ...@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model), enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False, enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm: ) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model}) llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
...@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): ...@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert add_perp < mul_perp assert add_perp < mul_perp
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"base_model,mul_model,add_model,quantization", "base_model,mul_model,add_model,quantization",
[ [
( pytest.param(
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply", "inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add", "inference-optimization/Qwen3-0.6B-debug-add",
"fp8", "fp8",
), ),
( pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty", "inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply", "inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add", "inference-optimization/DeepSeek-V3-debug-add",
"fp8", "fp8",
marks=[pytest.mark.slow_test],
), ),
( pytest.param(
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply", "inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add", "inference-optimization/Qwen3-0.6B-debug-add",
"mxfp8", "mxfp8",
marks=[pytest.mark.slow_test],
),
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"mxfp8",
marks=[
pytest.mark.slow_test,
pytest.mark.xfail(reason="mxfp4 & mla is not supported yet"),
],
), ),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
], ],
) )
def test_online_quantize_reload( def test_online_quantize_reload(
...@@ -195,6 +218,8 @@ def test_online_quantize_reload( ...@@ -195,6 +218,8 @@ def test_online_quantize_reload(
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model), enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False, enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm: ) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model}) llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
......
...@@ -1006,14 +1006,17 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1006,14 +1006,17 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
initialize_online_processing(layer) initialize_online_processing(layer)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32) w13_scale = torch.ones(
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32) layer.num_experts, device=w13.device, dtype=torch.float32
)
w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
......
...@@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo: ...@@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
information existed, a new entry is constructed information existed, a new entry is constructed
""" """
if layer not in LAYERWISE_INFO: if layer not in LAYERWISE_INFO:
LAYERWISE_INFO[layer] = LayerReloadingInfo() LAYERWISE_INFO[layer] = LayerReloadingInfo(
restore_metadata=({}, {}),
restore_device=torch.get_default_device(),
)
return LAYERWISE_INFO[layer] return LAYERWISE_INFO[layer]
...@@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module): ...@@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
for layer in model.modules(): for layer in model.modules():
info = get_layerwise_info(layer) info = get_layerwise_info(layer)
info.restore_metadata = capture_layer_to_meta(layer) info.restore_metadata = capture_layer_to_meta(layer)
info.restore_device = torch.get_default_device()
@torch.no_grad() @torch.no_grad()
...@@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module): ...@@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device # Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info) restore_layer_on_meta(layer, info)
# Wrap weight loaders to buffer loading
initialize_online_processing(layer) initialize_online_processing(layer)
def initialize_online_processing(layer: torch.nn.Module): def initialize_online_processing(layer: torch.nn.Module):
"""
Wrap a layer's weight loaders with online processing loaders.
Called by either `initialize_layerwise_reload` or an online quantization scheme,
prevents double wrapping in the case of online quantization + reloading
:param layer: layer whose parameter weight loaders will be wrapped
"""
info = get_layerwise_info(layer) info = get_layerwise_info(layer)
# Track loading progress to determine when to process/copy # Track loading progress to determine when to process/copy
...@@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon ...@@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif info.load_numel <= 0: elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load # first load but received no weights. This happens on dummy load
if info.kernel_tensors is None: if info.kernel_tensors is None:
materialize_layer(layer) materialize_layer(layer, info)
# reloading: place kernel tensors back as a fallback # reloading: place kernel tensors back as a fallback
else: else:
...@@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): ...@@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
4. Copies processed values back to original tensor storage 4. Copies processed values back to original tensor storage
""" """
# Materialize layer tensors onto device # Materialize layer tensors onto device
materialize_layer(layer) materialize_layer(layer, info)
# Reset online quantization flag so process_weights_after_loading # Reset online quantization flag so process_weights_after_loading
# will run again during reload # will run again during reload
......
...@@ -94,11 +94,12 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo): ...@@ -94,11 +94,12 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
layer.register_buffer(name, buffer) layer.register_buffer(name, buffer)
def materialize_layer(layer: torch.nn.Module) -> None: def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):
"""Materialize all meta tensors in a layer to actual tensors.""" """Materialize all meta tensors in a layer to actual tensors."""
if layer.__class__.__name__ in SKIP_MODULES: if layer.__class__.__name__ in SKIP_MODULES:
return return
with info.restore_device:
for name, tensor in get_layer_tensors(layer).items(): for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS: if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor)) setattr(layer, name, materialize_meta_tensor(tensor))
......
...@@ -13,21 +13,26 @@ LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] ...@@ -13,21 +13,26 @@ LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@dataclass @dataclass
class LayerReloadingInfo: class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading` # model format metadata, recorded by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {})) restore_metadata: LayerTensors
# kernel format (device), used to copy into when reloading only # device to materialize layers with, recorded by `record_metadata_for_reloading`
kernel_tensors: LayerTensors | None = None restore_device: torch.device
# track how many restored elements are ready for loading # track how many elements are ready for loading, used by `online_process_loader`
load_numel: int = 0 load_numel: int = 0
load_numel_total: int | None = None load_numel_total: int | None = None
# stores arguments and tensors ready for loading # used by `online_process_loader` to buffer args and tensors until ready to load
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list) loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
# kernel formatted tensors, copied into by `_layerwise_process` when reloading
kernel_tensors: LayerTensors | None = None
def reset(self): def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc] self.__init__( # type: ignore[misc]
restore_metadata=self.restore_metadata, restore_device=self.restore_device
)
def can_load(self) -> bool: def can_load(self) -> bool:
return self.load_numel_total is not None return self.load_numel_total is not None
...@@ -4943,10 +4943,6 @@ class GPUModelRunner( ...@@ -4943,10 +4943,6 @@ class GPUModelRunner(
# begin loading weights # begin loading weights
logger.info_once("Reloading weights inplace...", scope="local") logger.info_once("Reloading weights inplace...", scope="local")
load_device = (
self.vllm_config.load_config.device or self.vllm_config.device_config.device
)
with torch.device(load_device):
if is_checkpoint_format: if is_checkpoint_format:
# load weights from checkpoint/ original model format # load weights from checkpoint/ original model format
initialize_layerwise_reload(model) initialize_layerwise_reload(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment