Unverified commit b32260ab, authored by liangel-02 and committed by GitHub
Browse files

[torchao] safetensors integration (#25969)


Signed-off-by: Angel Li <liangel@meta.com>
parent f80e7866
...@@ -216,5 +216,22 @@ def test_reload_weights(): ...@@ -216,5 +216,22 @@ def test_reload_weights():
# print("-" * 60) # print("-" * 60)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated verbatim, so every fragment that is
    # followed by another one must end with an explicit space.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
    """Load a torchao Float8 weight-only safetensors checkpoint and generate.

    The test passes when greedy generation on a single prompt returns a
    non-empty result for the safetensors-serialized torchao checkpoint.
    """
    # Reset dynamo state before the model is built.
    torch._dynamo.reset()
    model_name = (
        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
    )
    with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

    assert output
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -59,6 +59,10 @@ class LoadConfig: ...@@ -59,6 +59,10 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS) This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM. initialization. However, it uses more CPU RAM.
- "torchao": Weights are loaded in upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.
Needs torchao >= 0.14.0
""" """
model_loader_extra_config: Union[dict, TensorizerConfig] = field( model_loader_extra_config: Union[dict, TensorizerConfig] = field(
default_factory=dict default_factory=dict
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json import json
from importlib.util import find_spec
from typing import Any, Optional from typing import Any, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from packaging import version
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
def torchao_version_at_least(torchao_version: str) -> bool:
    """Return whether torchao is installed at ``torchao_version`` or newer.

    Args:
        torchao_version: Minimum required version string, e.g. ``"0.14.0"``.

    Returns:
        ``True`` only if a ``torchao`` module can be found and its installed
        distribution version compares greater than or equal to
        ``torchao_version``. ``False`` if torchao is missing, its distribution
        metadata cannot be read, or either version string cannot be parsed.
    """
    # ``import importlib`` alone does not load the ``metadata`` submodule, so
    # ``importlib.metadata`` may be missing as an attribute; import it
    # explicitly to avoid an AttributeError.
    import importlib.metadata

    if not find_spec("torchao"):
        return False

    try:
        installed = version.parse(importlib.metadata.version("torchao"))
        required = version.parse(torchao_version)
    except (ImportError, version.InvalidVersion):
        # PackageNotFoundError subclasses ImportError, so missing metadata is
        # reported as "not at least" instead of propagating.
        return False
    return installed >= required
def should_skip(prefix: str, skip_modules: list[str]) -> bool: def should_skip(prefix: str, skip_modules: list[str]) -> bool:
""" """
Robust skipping logic: Robust skipping logic:
......
...@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME ...@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_safetensors_index_file_from_hf,
...@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
) )
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()} weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and # if we don't have `model.weight_metadata_and_attr_saved` defined and
......
...@@ -54,6 +54,8 @@ except ImportError: ...@@ -54,6 +54,8 @@ except ImportError:
SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader") SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
logger = init_logger(__name__) logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users # use system-level temp directory for file locks, so that multiple users
...@@ -602,6 +604,23 @@ def safetensors_weights_iterator( ...@@ -602,6 +604,23 @@ def safetensors_weights_iterator(
with open(st_file, "rb") as f: with open(st_file, "rb") as f:
state_dict = load(f.read()) state_dict = load(f.read())
yield from state_dict.items() yield from state_dict.items()
elif safetensors_load_strategy == "torchao":
if not torchao_version_at_least("0.14.0"):
raise ValueError(
"Please use torchao version >= 0.14.0 \
to load torchao safetensors checkpoint"
)
from torchao.prototype.safetensors.safetensors_support import (
unflatten_tensor_state_dict,
)
with safe_open(st_file, framework="pt") as f:
state_dict = {}
for name in f.keys(): # noqa: SIM118
state_dict[name] = f.get_tensor(name)
metadata = f.metadata()
updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
yield from updated_state_dict.items()
else: else:
with safe_open(st_file, framework="pt") as f: with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118 for name in f.keys(): # noqa: SIM118
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment