"vscode:/vscode.git/clone" did not exist on "97a042f3bca53417de6405a248e3d11fca568e2c"
Unverified Commit 72d14d0e authored by Sanger Steel's avatar Sanger Steel Committed by GitHub
Browse files

[Frontend] [Core] Integrate Tensorizer in to S3 loading machinery, allow...


[Frontend] [Core] Integrate Tensorizer in to S3 loading machinery, allow passing arbitrary arguments during save/load (#19619)
Signed-off-by: default avatarSanger Steel <sangersteel@gmail.com>
Co-authored-by: default avatarEta <esyra@coreweave.com>
parent e34d130c
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import argparse import argparse
import dataclasses import dataclasses
import json import json
import logging
import os import os
import uuid import uuid
...@@ -15,9 +16,13 @@ from vllm.model_executor.model_loader.tensorizer import ( ...@@ -15,9 +16,13 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, TensorizerConfig,
tensorize_lora_adapter, tensorize_lora_adapter,
tensorize_vllm_model, tensorize_vllm_model,
tensorizer_kwargs_arg,
) )
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
logger = logging.getLogger()
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
""" """
...@@ -119,7 +124,7 @@ vllm serve <model_path> \ ...@@ -119,7 +124,7 @@ vllm serve <model_path> \
""" """
def parse_args(): def get_parser():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and " description="An example script that can be used to serialize and "
"deserialize vLLM models. These models " "deserialize vLLM models. These models "
...@@ -135,13 +140,13 @@ def parse_args(): ...@@ -135,13 +140,13 @@ def parse_args():
required=False, required=False,
help="Path to a LoRA adapter to " help="Path to a LoRA adapter to "
"serialize along with model tensors. This can then be deserialized " "serialize along with model tensors. This can then be deserialized "
"along with the model by passing a tensorizer_config kwarg to " "along with the model by instantiating a TensorizerConfig object, "
"LoRARequest with type TensorizerConfig. See the docstring for this " "creating a dict from it with TensorizerConfig.to_serializable(), "
"for a usage example." "and passing it to LoRARequest's initializer with the kwarg "
"tensorizer_config_dict."
) )
subparsers = parser.add_subparsers(dest='command') subparsers = parser.add_subparsers(dest='command', required=True)
serialize_parser = subparsers.add_parser( serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`") 'serialize', help="Serialize a model to `--serialized-directory`")
...@@ -171,6 +176,14 @@ def parse_args(): ...@@ -171,6 +176,14 @@ def parse_args():
"where `suffix` is given by `--suffix` or a random UUID if not " "where `suffix` is given by `--suffix` or a random UUID if not "
"provided.") "provided.")
serialize_parser.add_argument(
"--serialization-kwargs",
type=tensorizer_kwargs_arg,
required=False,
help=("A JSON string containing additional keyword arguments to "
"pass to Tensorizer's TensorSerializer during "
"serialization."))
serialize_parser.add_argument( serialize_parser.add_argument(
"--keyfile", "--keyfile",
type=str, type=str,
...@@ -186,9 +199,17 @@ def parse_args(): ...@@ -186,9 +199,17 @@ def parse_args():
deserialize_parser.add_argument( deserialize_parser.add_argument(
"--path-to-tensors", "--path-to-tensors",
type=str, type=str,
required=True, required=False,
help="The local path or S3 URI to the model tensors to deserialize. ") help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--serialized-directory",
type=str,
required=False,
help="Directory with model artifacts for loading. Assumes a "
"model.tensors file exists therein. Can supersede "
"--path-to-tensors.")
deserialize_parser.add_argument( deserialize_parser.add_argument(
"--keyfile", "--keyfile",
type=str, type=str,
...@@ -196,11 +217,27 @@ def parse_args(): ...@@ -196,11 +217,27 @@ def parse_args():
help=("Path to a binary key to use to decrypt the model weights," help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption")) " if the model was serialized with encryption"))
TensorizerArgs.add_cli_args(deserialize_parser) deserialize_parser.add_argument(
"--deserialization-kwargs",
type=tensorizer_kwargs_arg,
required=False,
help=("A JSON string containing additional keyword arguments to "
"pass to Tensorizer's `TensorDeserializer` during "
"deserialization."))
return parser.parse_args() TensorizerArgs.add_cli_args(deserialize_parser)
return parser
def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
cfg: TensorizerConfig):
for k, v in extra_cfg.items():
if hasattr(cfg, k):
setattr(cfg, k, v)
logger.info(
"Updating TensorizerConfig with %s from "
"--model-loader-extra-config provided", k
)
def deserialize(args, tensorizer_config): def deserialize(args, tensorizer_config):
if args.lora_path: if args.lora_path:
...@@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config): ...@@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config):
lora_request=LoRARequest("sql-lora", lora_request=LoRARequest("sql-lora",
1, 1,
args.lora_path, args.lora_path,
tensorizer_config = tensorizer_config) tensorizer_config_dict = tensorizer_config
.to_serializable())
) )
) )
else: else:
...@@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config): ...@@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config):
def main(): def main():
args = parse_args() parser = get_parser()
args = parser.parse_args()
s3_access_key_id = (getattr(args, 's3_access_key_id', None) s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None)) or os.environ.get("S3_ACCESS_KEY_ID", None))
...@@ -265,13 +304,24 @@ def main(): ...@@ -265,13 +304,24 @@ def main():
else: else:
keyfile = None keyfile = None
extra_config = {}
if args.model_loader_extra_config: if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config) extra_config = json.loads(args.model_loader_extra_config)
tensorizer_args = \
TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors tensorizer_dir = (args.serialized_directory or
else: extra_config.get("tensorizer_dir"))
tensorizer_args = None tensorizer_uri = (getattr(args, "path_to_tensors", None)
or extra_config.get("tensorizer_uri"))
if tensorizer_dir and tensorizer_uri:
parser.error("--serialized-directory and --path-to-tensors "
"cannot both be provided")
if not tensorizer_dir and not tensorizer_uri:
parser.error("Either --serialized-directory or --path-to-tensors "
"must be provided")
if args.command == "serialize": if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in eng_args_dict = {f.name: getattr(args, f.name) for f in
...@@ -281,7 +331,7 @@ def main(): ...@@ -281,7 +331,7 @@ def main():
argparse.Namespace(**eng_args_dict) argparse.Namespace(**eng_args_dict)
) )
input_dir = args.serialized_directory.rstrip('/') input_dir = tensorizer_dir.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1: if engine_args.tensor_parallel_size > 1:
...@@ -292,21 +342,29 @@ def main(): ...@@ -292,21 +342,29 @@ def main():
tensorizer_config = TensorizerConfig( tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
encryption_keyfile=keyfile, encryption_keyfile=keyfile,
**credentials) serialization_kwargs=args.serialization_kwargs or {},
**credentials
)
if args.lora_path: if args.lora_path:
tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
tensorize_lora_adapter(args.lora_path, tensorizer_config) tensorize_lora_adapter(args.lora_path, tensorizer_config)
merge_extra_config_with_tensorizer_config(extra_config,
tensorizer_config)
tensorize_vllm_model(engine_args, tensorizer_config) tensorize_vllm_model(engine_args, tensorizer_config)
elif args.command == "deserialize": elif args.command == "deserialize":
if not tensorizer_args: tensorizer_config = TensorizerConfig(
tensorizer_config = TensorizerConfig( tensorizer_uri=args.path_to_tensors,
tensorizer_uri=args.path_to_tensors, tensorizer_dir=args.serialized_directory,
encryption_keyfile = keyfile, encryption_keyfile=keyfile,
**credentials deserialization_kwargs=args.deserialization_kwargs or {},
) **credentials
)
merge_extra_config_with_tensorizer_config(extra_config,
tensorizer_config)
deserialize(args, tensorizer_config) deserialize(args, tensorizer_config)
else: else:
raise ValueError("Either serialize or deserialize must be specified.") raise ValueError("Either serialize or deserialize must be specified.")
......
# testing # testing
pytest pytest
tensorizer>=2.9.0 tensorizer==2.10.1
pytest-forked pytest-forked
pytest-asyncio pytest-asyncio
pytest-rerunfailures pytest-rerunfailures
......
...@@ -11,7 +11,7 @@ datasets ...@@ -11,7 +11,7 @@ datasets
ray>=2.10.0,<2.45.0 ray>=2.10.0,<2.45.0
peft peft
pytest-asyncio pytest-asyncio
tensorizer>=2.9.0 tensorizer==2.10.1
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0
setuptools-scm>=8 setuptools-scm>=8
......
# testing # testing
pytest pytest
tensorizer>=2.9.0 tensorizer==2.10.1
pytest-forked pytest-forked
pytest-asyncio pytest-asyncio
pytest-rerunfailures pytest-rerunfailures
......
...@@ -739,7 +739,7 @@ tenacity==9.0.0 ...@@ -739,7 +739,7 @@ tenacity==9.0.0
# via # via
# lm-eval # lm-eval
# plotly # plotly
tensorizer==2.9.0 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
# via scikit-learn # via scikit-learn
......
...@@ -689,7 +689,7 @@ setup( ...@@ -689,7 +689,7 @@ setup(
install_requires=get_requirements(), install_requires=get_requirements(),
extras_require={ extras_require={
"bench": ["pandas", "datasets"], "bench": ["pandas", "datasets"],
"tensorizer": ["tensorizer>=2.9.0"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"], "fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing "audio": ["librosa", "soundfile"], # Required for audio processing
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import json import os
import tempfile import tempfile
import openai import openai
...@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri): ...@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora): def server(model_uri, tensorize_model_and_lora):
model_loader_extra_config = { # In this case, model_uri is a directory with a model.tensors
"tensorizer_uri": model_uri, # file and all necessary model artifacts, particularly a
} # HF `config.json` file. In this case, Tensorizer can infer the
# `TensorizerConfig` so --model-loader-extra-config can be completely
# omitted.
## Start OpenAI API server ## Start OpenAI API server
args = [ args = [
"--load-format", "tensorizer", "--device", "cuda", "--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
"--model-loader-extra-config", "--enable-lora"
json.dumps(model_loader_extra_config), "--enable-lora"
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server yield remote_server
......
...@@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, ...@@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
str(tp_size), "serialize", "--serialized-directory", str(tp_size), "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
], ],
check=True, check=True,
capture_output=True, capture_output=True,
...@@ -195,7 +196,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, ...@@ -195,7 +196,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
tensor_parallel_size=2, tensor_parallel_size=2,
max_loras=2) max_loras=2)
tensorizer_config_dict = tensorizer_config.to_dict() tensorizer_config_dict = tensorizer_config.to_serializable()
print("lora adapter created") print("lora adapter created")
assert do_sample(loaded_vllm_model, assert do_sample(loaded_vllm_model,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
import pytest import pytest
from vllm import LLM, EngineArgs
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.worker.worker_base import WorkerWrapperBase
MODEL_REF = "facebook/opt-125m"
@pytest.fixture()
def model_ref():
return MODEL_REF
@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -11,7 +30,73 @@ def cleanup(): ...@@ -11,7 +30,73 @@ def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True) cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
def noop(*args, **kwargs):
return None
args = EngineArgs(model=model_ref)
tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)
tensorizer_mod.tensorize_vllm_model(args, tc)
yield tmp_path
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def tensorizer_config(): def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm") config = TensorizerConfig(tensorizer_uri="vllm")
return config return config
@pytest.fixture()
def model_path(model_ref, tmp_path):
yield tmp_path / model_ref / "model.tensors"
def assert_from_collective_rpc(engine: LLM, closure: Callable,
closure_kwargs: dict):
res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
return all(res)
# This is an object pulled from tests/v1/engine/test_engine_core.py
# Modified to strip the `load_model` method from its `_init_executor`
# method. It's purely used as a dummy utility to run methods that test
# Tensorizer functionality
class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
@property
def max_concurrent_batches(self) -> int:
return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import gc import gc
import json
import os import os
import pathlib import pathlib
import subprocess import subprocess
import sys
from typing import Any
import pytest import pytest
import torch import torch
from vllm import SamplingParams import vllm.model_executor.model_loader.tensorizer
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer, TensorSerializer,
is_vllm_tensorized, is_vllm_tensorized,
open_stream, open_stream,
tensorize_vllm_model) tensorize_vllm_model)
from vllm.model_executor.model_loader.tensorizer_loader import (
BLACKLISTED_TENSORIZER_ARGS)
# yapf: enable # yapf: enable
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
from ..utils import VLLM_PATH from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import DummyExecutor, assert_from_collective_rpc
try: try:
import tensorizer
from tensorizer import EncryptionParams from tensorizer import EncryptionParams
except ImportError: except ImportError:
tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment] tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment]
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams") EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
class TensorizerCaughtError(Exception):
pass
EXAMPLES_PATH = VLLM_PATH / "examples" EXAMPLES_PATH = VLLM_PATH / "examples"
pytest_plugins = "pytest_asyncio",
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -40,9 +55,37 @@ prompts = [ ...@@ -40,9 +55,37 @@ prompts = [
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join( def patch_init_and_catch_error(self, obj, method_name,
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") expected_error: type[Exception]):
original = getattr(obj, method_name, None)
if original is None:
raise ValueError("Method '{}' not found.".format(method_name))
def wrapper(*args, **kwargs):
try:
return original(*args, **kwargs)
except expected_error as err:
raise TensorizerCaughtError from err
setattr(obj, method_name, wrapper)
self.load_model()
def assert_specific_tensorizer_error_is_raised(
executor,
obj: Any,
method_name: str,
expected_error: type[Exception],
):
with pytest.raises(TensorizerCaughtError):
executor.collective_rpc(patch_init_and_catch_error,
args=(
obj,
method_name,
expected_error,
))
def is_curl_installed(): def is_curl_installed():
...@@ -81,11 +124,10 @@ def test_can_deserialize_s3(vllm_runner): ...@@ -81,11 +124,10 @@ def test_can_deserialize_s3(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs( def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path): model_ref, vllm_runner, tmp_path, model_path):
args = EngineArgs(model=model_ref) args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") key_path = tmp_path / model_ref / "model.key"
key_path = tmp_path / (model_ref + ".key")
write_keyfile(key_path) write_keyfile(key_path)
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
...@@ -111,9 +153,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -111,9 +153,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path): tmp_path, model_ref,
model_path):
with hf_runner(model_ref) as hf_model: with hf_runner(model_ref) as hf_model:
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50 max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream: with open_stream(model_path, "wb+") as stream:
...@@ -123,7 +165,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -123,7 +165,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=str(model_path),
num_readers=1, num_readers=1,
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy( deserialized_outputs = loaded_hf_model.generate_greedy(
...@@ -132,7 +174,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -132,7 +174,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
def test_load_without_tensorizer_load_format(vllm_runner, capfd): def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
model = None model = None
try: try:
model = vllm_runner( model = vllm_runner(
...@@ -150,7 +192,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd): ...@@ -150,7 +192,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd): def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
model_ref):
model = None model = None
try: try:
model = vllm_runner( model = vllm_runner(
...@@ -208,7 +251,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( ...@@ -208,7 +251,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
outputs = base_model.generate(prompts, sampling_params) outputs = base_model.generate(prompts, sampling_params)
# load model with two shards and serialize with encryption # load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors")) model_path = str(tmp_path / model_ref / "model-%02d.tensors")
key_path = tmp_path / (model_ref + ".key") key_path = tmp_path / (model_ref + ".key")
tensorizer_config = TensorizerConfig( tensorizer_config = TensorizerConfig(
...@@ -242,13 +285,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( ...@@ -242,13 +285,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
@pytest.mark.flaky(reruns=3) @pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
tmp_path, model_path):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref, device="cuda") args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
...@@ -264,3 +306,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -264,3 +306,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
# noqa: E501 # noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
# For backwards compatibility, ensure Tensorizer can be still be loaded
# for inference by passing the model reference name, not a local/S3 dir,
# and the location of the model tensors
model_dir = just_serialize_model_tensors
extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"}
## Start OpenAI API server
args = [
"--load-format",
"tensorizer",
"--model-loader-extra-config",
json.dumps(extra_config),
]
with RemoteOpenAIServer(model_ref, args):
# This test only concerns itself with being able to load the model
# and successfully initialize the server
pass
def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
llm = LLM(model=model_ref, )
def serialization_test(self, *args, **kwargs):
# This is performed in the ephemeral worker process, so monkey-patching
# will actually work, and cleanup is guaranteed so don't
# need to reset things
original_dict = serialization_params
to_compare = {}
original = tensorizer.serialization.TensorSerializer.__init__
def tensorizer_serializer_wrapper(self, *args, **kwargs):
nonlocal to_compare
to_compare = kwargs.copy()
return original(self, *args, **kwargs)
tensorizer.serialization.TensorSerializer.__init__ = (
tensorizer_serializer_wrapper)
tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
self.save_tensorized_model(tensorizer_config=tensorizer_config, )
return to_compare | original_dict == to_compare
kwargs = {"tensorizer_config": config.to_serializable()}
assert assert_from_collective_rpc(llm, serialization_test, kwargs)
def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
tmp_path, capfd):
deserialization_kwargs = {
"num_readers": "bar", # illegal value
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
tensorizer.serialization.TensorDeserializer,
"__init__",
TypeError,
)
def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
deserialization_kwargs = {
"num_readers": 1,
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
stream_kwargs = {"mode": "foo"}
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
stream_kwargs=stream_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
vllm.model_executor.model_loader.tensorizer,
"open_stream",
ValueError,
)
@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
model_ref = "facebook/opt-125m"
suffix = "test"
try:
result = subprocess.run([
sys.executable,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
model_ref, "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
],
check=True,
capture_output=True,
text=True)
except subprocess.CalledProcessError as e:
print("Tensorizing failed.")
print("STDOUT:\n", e.stdout)
print("STDERR:\n", e.stderr)
raise
assert "Successfully serialized" in result.stdout
# Next, try to serve with vllm serve
model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"
model_loader_extra_config = {
"tensorizer_uri": str(model_uri),
"stream_kwargs": {
"force_http": False,
},
"deserialization_kwargs": {
"verify_hash": True,
"num_readers": 8,
}
}
cmd = [
"-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
"--load-format", "tensorizer", model_ref,
"--model-loader-extra-config",
json.dumps(model_loader_extra_config, indent=2)
]
proc = await asyncio.create_subprocess_exec(
sys.executable,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
fut = proc.stdout.readuntil(b"Application startup complete.")
try:
await asyncio.wait_for(fut, 180)
except asyncio.TimeoutError:
pytest.fail("Server did not start successfully")
finally:
proc.terminate()
await proc.communicate()
@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
illegal_value):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}
try:
vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=loader_tc,
)
except RuntimeError:
out, err = capfd.readouterr()
combined_output = out + err
assert (f"ValueError: {illegal_value} is not an allowed "
f"Tensorizer argument.") in combined_output
...@@ -686,8 +686,11 @@ class ModelConfig: ...@@ -686,8 +686,11 @@ class ModelConfig:
# If tokenizer is same as model, download to same directory # If tokenizer is same as model, download to same directory
if model == tokenizer: if model == tokenizer:
s3_model.pull_files( s3_model.pull_files(model,
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) ignore_pattern=[
"*.pt", "*.safetensors", "*.bin",
"*.tensors"
])
self.tokenizer = s3_model.dir self.tokenizer = s3_model.dir
return return
...@@ -695,7 +698,8 @@ class ModelConfig: ...@@ -695,7 +698,8 @@ class ModelConfig:
if is_s3(tokenizer): if is_s3(tokenizer):
s3_tokenizer = S3Model() s3_tokenizer = S3Model()
s3_tokenizer.pull_files( s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model,
ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
self.tokenizer = s3_tokenizer.dir self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
......
...@@ -58,7 +58,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: ...@@ -58,7 +58,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _parse_type(val: str) -> T: def _parse_type(val: str) -> T:
try: try:
if return_type is json.loads and not re.match("^{.*}$", val): if return_type is json.loads and not re.match(
r"(?s)^\s*{.*}\s*$", val):
return cast(T, nullable_kvs(val)) return cast(T, nullable_kvs(val))
return return_type(val) return return_type(val)
except ValueError as e: except ValueError as e:
...@@ -80,7 +81,7 @@ def optional_type( ...@@ -80,7 +81,7 @@ def optional_type(
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val): if not re.match(r"(?s)^\s*{.*}\s*$", val):
return str(val) return str(val)
return optional_type(json.loads)(val) return optional_type(json.loads)(val)
...@@ -1001,11 +1002,42 @@ class EngineArgs: ...@@ -1001,11 +1002,42 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype, override_attention_dtype=self.override_attention_dtype,
) )
def valid_tensorizer_config_provided(self) -> bool:
"""
Checks if a parseable TensorizerConfig was passed to
self.model_loader_extra_config. It first checks if the config passed
is a dict or a TensorizerConfig object directly, and if the latter is
true (by checking that the object has TensorizerConfig's
.to_serializable() method), converts it in to a serializable dict
format
"""
if self.model_loader_extra_config:
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
try:
self.model_loader_extra_config[allowed_to_pass]
return False
except KeyError:
pass
return True
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
if self.quantization == "bitsandbytes": if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes" self.load_format = "bitsandbytes"
if (self.load_format == "tensorizer"
and self.valid_tensorizer_config_provided()):
logger.info("Inferring Tensorizer args from %s", self.model)
self.model_loader_extra_config = {"tensorizer_dir": self.model}
else:
logger.info(
"Using Tensorizer args from --model-loader-extra-config. "
"Note that you can now simply pass the S3 directory in the "
"model tag instead of providing the JSON string.")
return LoadConfig( return LoadConfig(
load_format=self.load_format, load_format=self.load_format,
download_dir=self.download_dir, download_dir=self.download_dir,
......
...@@ -245,9 +245,10 @@ class LoRAModel(AdapterModel): ...@@ -245,9 +245,10 @@ class LoRAModel(AdapterModel):
lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
"adapter_model.tensors") "adapter_model.tensors")
tensorizer_args = tensorizer_config._construct_tensorizer_args() tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(lora_tensor_path, tensors = TensorDeserializer(
dtype=tensorizer_config.dtype, lora_tensor_path,
**tensorizer_args.deserializer_params) dtype=tensorizer_config.dtype,
**tensorizer_args.deserialization_kwargs)
check_unexpected_modules(tensors) check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path): elif os.path.isfile(lora_tensor_path):
......
...@@ -106,7 +106,7 @@ class PEFTHelper: ...@@ -106,7 +106,7 @@ class PEFTHelper:
"adapter_config.json") "adapter_config.json")
with open_stream(lora_config_path, with open_stream(lora_config_path,
mode="rb", mode="rb",
**tensorizer_args.stream_params) as f: **tensorizer_args.stream_kwargs) as f:
config = json.load(f) config = json.load(f)
logger.info("Successfully deserialized LoRA config from %s", logger.info("Successfully deserialized LoRA config from %s",
......
...@@ -5,18 +5,18 @@ import argparse ...@@ -5,18 +5,18 @@ import argparse
import contextlib import contextlib
import contextvars import contextvars
import dataclasses import dataclasses
import io
import json import json
import os import os
import tempfile
import threading import threading
import time import time
from collections.abc import Generator from collections.abc import Generator, MutableMapping
from dataclasses import dataclass from dataclasses import asdict, dataclass, field, fields
from functools import partial from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
import regex as re import regex as re
import torch import torch
from huggingface_hub import snapshot_download
from torch import nn from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -39,10 +39,6 @@ try: ...@@ -39,10 +39,6 @@ try:
from tensorizer.utils import (convert_bytes, get_mem_usage, from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor) no_init_or_tensor)
_read_stream, _write_stream = (partial(
open_stream,
mode=mode,
) for mode in ("rb", "wb+"))
except ImportError: except ImportError:
tensorizer = PlaceholderModule("tensorizer") tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
...@@ -54,9 +50,6 @@ except ImportError: ...@@ -54,9 +50,6 @@ except ImportError:
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
_read_stream = tensorizer.placeholder_attr("_read_stream")
_write_stream = tensorizer.placeholder_attr("_write_stream")
__all__ = [ __all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
...@@ -66,6 +59,23 @@ __all__ = [ ...@@ -66,6 +59,23 @@ __all__ = [
logger = init_logger(__name__) logger = init_logger(__name__)
def is_valid_deserialization_uri(uri: Optional[str]) -> bool:
if uri:
scheme = uri.lower().split("://")[0]
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
return False
def tensorizer_kwargs_arg(value):
loaded = json.loads(value)
if not isinstance(loaded, dict):
raise argparse.ArgumentTypeError(
f"Not deserializable to dict: {value}. serialization_kwargs and "
f"deserialization_kwargs must be "
f"deserializable from a JSON string to a dictionary. ")
return loaded
class MetaTensorMode(TorchDispatchMode): class MetaTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None): def __torch_dispatch__(self, func, types, args=(), kwargs=None):
...@@ -137,54 +147,143 @@ class _NoInitOrTensorImpl: ...@@ -137,54 +147,143 @@ class _NoInitOrTensorImpl:
@dataclass @dataclass
class TensorizerConfig: class TensorizerConfig(MutableMapping):
tensorizer_uri: Union[str, None] = None tensorizer_uri: Optional[str] = None
vllm_tensorized: Optional[bool] = False tensorizer_dir: Optional[str] = None
verify_hash: Optional[bool] = False vllm_tensorized: Optional[bool] = None
verify_hash: Optional[bool] = None
num_readers: Optional[int] = None num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None s3_endpoint: Optional[str] = None
model_class: Optional[type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
lora_dir: Optional[str] = None lora_dir: Optional[str] = None
_is_sharded: bool = False stream_kwargs: Optional[dict[str, Any]] = None
serialization_kwargs: Optional[dict[str, Any]] = None
deserialization_kwargs: Optional[dict[str, Any]] = None
_extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False,
default=None)
model_class: Optional[type[torch.nn.Module]] = field(init=False,
default=None)
hf_config: Optional[PretrainedConfig] = field(init=False, default=None)
dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None)
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
"""
Args for the TensorizerConfig class. These are used to configure the
behavior of model serialization and deserialization using Tensorizer.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is a required
argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of `tensorizer_uri`
where the `model.tensors` file will be assumed to be in this
directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
def __post_init__(self): def __post_init__(self):
# check if the configuration is for a sharded vLLM model # check if the configuration is for a sharded vLLM model
self._is_sharded = isinstance(self.tensorizer_uri, str) \ self._is_sharded = isinstance(self.tensorizer_uri, str) \
and re.search(r'%0\dd', self.tensorizer_uri) is not None and re.search(r'%0\dd', self.tensorizer_uri) is not None
if not self.tensorizer_uri and not self.lora_dir:
raise ValueError("tensorizer_uri must be provided.")
if not self.tensorizer_uri and self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
"provided.")
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
self.lora_dir = self.tensorizer_dir
@classmethod
def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
cfg = TensorizerConfig(*args, **kwargs)
return dataclasses.asdict(cfg)
def to_dict(self) -> dict[str, Any]: if self.tensorizer_dir and self.tensorizer_uri:
return dataclasses.asdict(self) raise ValueError(
"Either tensorizer_dir or tensorizer_uri must be provided, "
"not both.")
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "
"Use lora_dir exclusively when serializing LoRA adapters, "
"and tensorizer_dir or tensorizer_uri otherwise.")
if not self.tensorizer_uri:
if self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
elif self.tensorizer_dir:
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
else:
raise ValueError("Unable to resolve tensorizer_uri. "
"A valid tensorizer_uri or tensorizer_dir "
"must be provided for deserialization, and a "
"valid tensorizer_uri, tensorizer_uri, or "
"lora_dir for serialization.")
else:
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.serialization_kwargs:
self.serialization_kwargs = {}
if not self.deserialization_kwargs:
self.deserialization_kwargs = {}
def to_serializable(self) -> dict[str, Any]:
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
# support for morphing back and forth between itself and its dict
# representation
# TensorizerConfig's representation as a dictionary is meant to be
# linked to TensorizerConfig in such a way that the following is
# technically initializable:
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
# This means the dict must not retain non-initializable parameters
# and post-init attribute states
# Also don't want to retain private and unset parameters, so only retain
# not None values and public attributes
raw_tc_dict = asdict(self)
blacklisted = []
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
tc_dict = {}
for k, v in raw_tc_dict.items():
if (k not in blacklisted and k not in tc_dict
and not k.startswith("_") and v is not None):
tc_dict[k] = v
return tc_dict
def _construct_tensorizer_args(self) -> "TensorizerArgs": def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = { return TensorizerArgs(self) # type: ignore
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
...@@ -209,81 +308,76 @@ class TensorizerConfig: ...@@ -209,81 +308,76 @@ class TensorizerConfig:
tensorizer_args = self._construct_tensorizer_args() tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri, return open_stream(self.tensorizer_uri,
**tensorizer_args.stream_params) **tensorizer_args.stream_kwargs)
def keys(self):
return self._keys
def __len__(self):
return len(fields(self))
def __iter__(self):
return iter(self._fields)
def __getitem__(self, item: str) -> Any:
if item not in self.keys():
raise KeyError(item)
return getattr(self, item)
def __setitem__(self, key: str, value: Any) -> None:
if key not in self.keys():
# Disallow modifying invalid keys
raise KeyError(key)
setattr(self, key, value)
def __delitem__(self, key, /):
if key not in self.keys():
raise KeyError(key)
delattr(self, key)
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
@dataclass @dataclass
class TensorizerArgs: class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, tensorizer_uri: Optional[str] = None
bytes, os.PathLike, int] tensorizer_dir: Optional[str] = None
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
"""
Args for the TensorizerAgent class. These are used to configure the behavior
of the TensorDeserializer when loading tensors from a serialized model.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
"""
def __post_init__(self): def __init__(self, tensorizer_config: TensorizerConfig):
self.file_obj = self.tensorizer_uri for k, v in tensorizer_config.items():
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID setattr(self, k, v)
self.s3_secret_access_key = (self.s3_secret_access_key self.file_obj = tensorizer_config.tensorizer_uri
self.s3_access_key_id = (tensorizer_config.s3_access_key_id
or envs.S3_ACCESS_KEY_ID)
self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key
or envs.S3_SECRET_ACCESS_KEY) or envs.S3_SECRET_ACCESS_KEY)
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_params = {
"s3_access_key_id": self.s3_access_key_id, self.stream_kwargs = {
"s3_secret_access_key": self.s3_secret_access_key, "s3_access_key_id": tensorizer_config.s3_access_key_id,
"s3_endpoint": self.s3_endpoint, "s3_secret_access_key": tensorizer_config.s3_secret_access_key,
"s3_endpoint": tensorizer_config.s3_endpoint,
**(tensorizer_config.stream_kwargs or {})
} }
self.deserializer_params = { self.deserialization_kwargs = {
"verify_hash": self.verify_hash, "verify_hash": tensorizer_config.verify_hash,
"encryption": self.encryption_keyfile, "encryption": tensorizer_config.encryption_keyfile,
"num_readers": self.num_readers "num_readers": tensorizer_config.num_readers,
**(tensorizer_config.deserialization_kwargs or {})
} }
if self.encryption_keyfile: if self.encryption_keyfile:
with open_stream( with open_stream(
self.encryption_keyfile, tensorizer_config.encryption_keyfile,
**self.stream_params, **self.stream_kwargs,
) as stream: ) as stream:
key = stream.read() key = stream.read()
decryption_params = DecryptionParams.from_key(key) decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params self.deserialization_kwargs['encryption'] = decryption_params
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...@@ -405,15 +499,22 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig, ...@@ -405,15 +499,22 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig,
def deserialize_tensorizer_model(model: nn.Module, def deserialize_tensorizer_model(model: nn.Module,
tensorizer_config: TensorizerConfig) -> None: tensorizer_config: TensorizerConfig) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args() tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme.")
before_mem = get_mem_usage() before_mem = get_mem_usage()
start = time.perf_counter() start = time.perf_counter()
with _read_stream( with open_stream(
tensorizer_config.tensorizer_uri, tensorizer_config.tensorizer_uri,
**tensorizer_args.stream_params) as stream, TensorDeserializer( mode="rb",
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
stream, stream,
dtype=tensorizer_config.dtype, dtype=tensorizer_config.dtype,
device=f'cuda:{torch.cuda.current_device()}', device=torch.device("cuda", torch.cuda.current_device()),
**tensorizer_args.deserializer_params) as deserializer: **tensorizer_args.deserialization_kwargs) as deserializer:
deserializer.load_into_module(model) deserializer.load_into_module(model)
end = time.perf_counter() end = time.perf_counter()
...@@ -442,9 +543,9 @@ def tensorizer_weights_iterator( ...@@ -442,9 +543,9 @@ def tensorizer_weights_iterator(
"examples/others/tensorize_vllm_model.py example script " "examples/others/tensorize_vllm_model.py example script "
"for serializing vLLM models.") "for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params deserializer_args = tensorizer_args.deserialization_kwargs
stream_params = tensorizer_args.stream_params stream_kwargs = tensorizer_args.stream_kwargs
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
with TensorDeserializer(stream, **deserializer_args, with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state: device="cpu") as state:
yield from state.items() yield from state.items()
...@@ -465,8 +566,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: ...@@ -465,8 +566,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
""" """
tensorizer_args = tensorizer_config._construct_tensorizer_args() tensorizer_args = tensorizer_config._construct_tensorizer_args()
deserializer = TensorDeserializer(open_stream( deserializer = TensorDeserializer(open_stream(
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
**tensorizer_args.deserializer_params, **tensorizer_args.deserialization_kwargs,
lazy_load=True) lazy_load=True)
if tensorizer_config.vllm_tensorized: if tensorizer_config.vllm_tensorized:
logger.warning( logger.warning(
...@@ -477,13 +578,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: ...@@ -477,13 +578,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
return ".vllm_tensorized_marker" in deserializer return ".vllm_tensorized_marker" in deserializer
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs,
served_model_name: Union[str, list[str], None]) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}.")
with tempfile.TemporaryDirectory() as tmpdir:
snapshot_download(served_model_name,
local_dir=tmpdir,
ignore_patterns=[
"*.pt", "*.safetensors", "*.bin", "*.cache",
"*.gitattributes", "*.md"
])
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with open(artifact.path, "rb") as f, open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs) as stream:
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())
def serialize_vllm_model( def serialize_vllm_model(
model: nn.Module, model: nn.Module,
tensorizer_config: TensorizerConfig, tensorizer_config: TensorizerConfig,
model_config: "ModelConfig",
) -> nn.Module: ) -> nn.Module:
model.register_parameter( model.register_parameter(
"vllm_tensorized_marker", "vllm_tensorized_marker",
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
tensorizer_args = tensorizer_config._construct_tensorizer_args() tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None encryption_params = None
...@@ -497,10 +626,16 @@ def serialize_vllm_model( ...@@ -497,10 +626,16 @@ def serialize_vllm_model(
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank() output_file = output_file % get_tensor_model_parallel_rank()
with _write_stream(output_file, **tensorizer_args.stream_params) as stream: with open_stream(output_file, mode="wb+",
serializer = TensorSerializer(stream, encryption=encryption_params) **tensorizer_args.stream_kwargs) as stream:
serializer = TensorSerializer(stream,
encryption=encryption_params,
**tensorizer_config.serialization_kwargs)
serializer.write_module(model) serializer.write_module(model)
serializer.close() serializer.close()
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
logger.info("Successfully serialized model to %s", str(output_file)) logger.info("Successfully serialized model to %s", str(output_file))
return model return model
...@@ -522,8 +657,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs", ...@@ -522,8 +657,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
if generate_keyfile and (keyfile := if generate_keyfile and (keyfile :=
tensorizer_config.encryption_keyfile) is not None: tensorizer_config.encryption_keyfile) is not None:
encryption_params = EncryptionParams.random() encryption_params = EncryptionParams.random()
with _write_stream( with open_stream(
keyfile, keyfile,
mode="wb+",
s3_access_key_id=tensorizer_config.s3_access_key_id, s3_access_key_id=tensorizer_config.s3_access_key_id,
s3_secret_access_key=tensorizer_config.s3_secret_access_key, s3_secret_access_key=tensorizer_config.s3_secret_access_key,
s3_endpoint=tensorizer_config.s3_endpoint, s3_endpoint=tensorizer_config.s3_endpoint,
...@@ -537,13 +673,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs", ...@@ -537,13 +673,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
engine = LLMEngine.from_engine_args(engine_args) engine = LLMEngine.from_engine_args(engine_args)
engine.model_executor.collective_rpc( engine.model_executor.collective_rpc(
"save_tensorized_model", "save_tensorized_model",
kwargs=dict(tensorizer_config=tensorizer_config), kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
) )
else: else:
engine = V1LLMEngine.from_vllm_config(engine_config) engine = V1LLMEngine.from_vllm_config(engine_config)
engine.collective_rpc( engine.collective_rpc(
"save_tensorized_model", "save_tensorized_model",
kwargs=dict(tensorizer_config=tensorizer_config), kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
) )
...@@ -586,14 +722,14 @@ def tensorize_lora_adapter(lora_path: str, ...@@ -586,14 +722,14 @@ def tensorize_lora_adapter(lora_path: str,
with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json", with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
mode="wb+", mode="wb+",
**tensorizer_args.stream_params) as f: **tensorizer_args.stream_kwargs) as f:
f.write(json.dumps(config).encode("utf-8")) f.write(json.dumps(config).encode("utf-8"))
lora_uri = (f"{tensorizer_config.lora_dir}" lora_uri = (f"{tensorizer_config.lora_dir}"
f"/adapter_model.tensors") f"/adapter_model.tensors")
with open_stream(lora_uri, mode="wb+", with open_stream(lora_uri, mode="wb+",
**tensorizer_args.stream_params) as f: **tensorizer_args.stream_kwargs) as f:
serializer = TensorSerializer(f) serializer = TensorSerializer(f)
serializer.write_state_dict(tensors) serializer.write_state_dict(tensors)
serializer.close() serializer.close()
......
...@@ -20,6 +20,18 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture, ...@@ -20,6 +20,18 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
logger = init_logger(__name__) logger = init_logger(__name__)
BLACKLISTED_TENSORIZER_ARGS = {
"device", # vLLM decides this
"dtype", # vLLM decides this
"mode", # Not meant to be configurable by the user
}
def validate_config(config: dict):
for k, v in config.items():
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
class TensorizerLoader(BaseModelLoader): class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library.""" """Model loader using CoreWeave's tensorizer library."""
...@@ -29,6 +41,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -29,6 +41,7 @@ class TensorizerLoader(BaseModelLoader):
if isinstance(load_config.model_loader_extra_config, TensorizerConfig): if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config self.tensorizer_config = load_config.model_loader_extra_config
else: else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig( self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config) **load_config.model_loader_extra_config)
...@@ -118,10 +131,12 @@ class TensorizerLoader(BaseModelLoader): ...@@ -118,10 +131,12 @@ class TensorizerLoader(BaseModelLoader):
def save_model( def save_model(
model: torch.nn.Module, model: torch.nn.Module,
tensorizer_config: Union[TensorizerConfig, dict], tensorizer_config: Union[TensorizerConfig, dict],
model_config: ModelConfig,
) -> None: ) -> None:
if isinstance(tensorizer_config, dict): if isinstance(tensorizer_config, dict):
tensorizer_config = TensorizerConfig(**tensorizer_config) tensorizer_config = TensorizerConfig(**tensorizer_config)
serialize_vllm_model( serialize_vllm_model(
model=model, model=model,
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
model_config=model_config,
) )
...@@ -1820,6 +1820,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1820,6 +1820,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
TensorizerLoader.save_model( TensorizerLoader.save_model(
self.model, self.model,
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
model_config=self.model_config,
) )
def _get_prompt_logprobs_dict( def _get_prompt_logprobs_dict(
......
...@@ -1246,6 +1246,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1246,6 +1246,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
TensorizerLoader.save_model( TensorizerLoader.save_model(
self.model, self.model,
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
model_config=self.model_config,
) )
def get_max_block_per_batch(self) -> int: def get_max_block_per_batch(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment