Unverified Commit 59a0192f authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Interface for accessing model from `VllmRunner` (#10353)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 83609791
......@@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
class HfRunner:
......@@ -930,6 +931,10 @@ class VllmRunner:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.score for req_output in req_outputs]
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.model.llm_engine.model_executor
return executor.apply_model(func)
def __enter__(self):
return self
......
......@@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
assert not os.path.exists(".marker")
engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomUniExecutor)
model=model,
distributed_executor_backend=CustomUniExecutor,
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
......
......@@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = model.model.llm_engine.model_config
model_tokenizer = model.model.llm_engine.tokenizer
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
......@@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
def check_model(model):
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
vllm_model.apply_model(check_model)
# assert output
assert output
......@@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = model.model.llm_engine.model_config
model_tokenizer = model.model.llm_engine.tokenizer
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
......@@ -84,12 +84,13 @@ def test_roberta_model_loading_with_params(vllm_runner):
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize
vllm_model.apply_model(check_model)
# assert output
assert output
......@@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
model_name = "FacebookAI/roberta-base"
with vllm_runner(model_name=model_name,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_tokenizer = model.model.llm_engine.tokenizer
model_tokenizer = vllm_model.model.llm_engine.tokenizer
assert model_tokenizer.tokenizer_id == model_name
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert not hasattr(model, "lm_head")
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert not hasattr(model, "lm_head")
assert isinstance(model._pooler, CLSPool)
vllm_model.apply_model(check_model)
assert output
......@@ -33,10 +33,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
......
......@@ -51,10 +51,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
......
......@@ -73,10 +73,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
check_logprobs_close(
outputs_0_lst=hf_outputs,
......
......@@ -5,7 +5,6 @@ import pytest
import torch
from PIL import Image
from vllm.entrypoints.llm import LLM
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
......@@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
def batch_make_image_embeddings(
image_batches: List[Union[Image.Image, List[Image.Image]]], processor,
llm: LLM) -> List[Qwen2VLPromptImageEmbeddingInput]:
llm: VllmRunner) -> List[Qwen2VLPromptImageEmbeddingInput]:
"""batched image embeddings for Qwen2-VL
This will infer all images' embeddings in a single batch,
......@@ -106,17 +105,19 @@ def batch_make_image_embeddings(
image_grid_thw = preprocess_result["image_grid_thw"]
# pixel values to embeddings & grid_thws
def get_image_embeds(model):
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
image_grid_thw_on_device = image_grid_thw.to(visual.device,
dtype=torch.int64)
image_embeds = visual(pixel_values_on_device,
return visual(pixel_values_on_device,
grid_thw=image_grid_thw_on_device)
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches
result: List[Qwen2VLPromptImageEmbeddingInput] = []
image_counter = 0
......@@ -150,7 +151,7 @@ def batch_make_image_embeddings(
def batch_make_video_embeddings(
video_batches: PromptVideoInput, processor,
llm: LLM) -> List[Qwen2VLPromptVideoEmbeddingInput]:
llm: VllmRunner) -> List[Qwen2VLPromptVideoEmbeddingInput]:
"""batched video embeddings for Qwen2-VL
A NDArray represents a single video's all frames.
......@@ -187,17 +188,19 @@ def batch_make_video_embeddings(
video_grid_thw = preprocess_result["video_grid_thw"]
# pixel values to embeddings & grid_thws
def get_image_embeds(model):
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
video_grid_thw_on_device = video_grid_thw.to(visual.device,
dtype=torch.int64)
video_embeds = visual(pixel_values_on_device,
return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device)
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches
result: List[Qwen2VLPromptVideoEmbeddingInput] = []
video_counter = 0
......@@ -278,9 +281,9 @@ def run_embedding_input_test(
max_tokens,
num_logprobs=num_logprobs,
images=batch_make_image_embeddings(
images, processor, vllm_model.model) if images else None,
images, processor, vllm_model) if images else None,
videos=batch_make_video_embeddings(
videos, processor, vllm_model.model) if videos else None)
videos, processor, vllm_model) if videos else None)
for prompts, images, videos in inputs
]
......
......@@ -24,10 +24,13 @@ def test_classification_models(
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
with hf_runner(model,
dtype=dtype,
......
......@@ -62,10 +62,13 @@ def test_models(
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
......
......@@ -30,7 +30,8 @@ from vllm.platforms import current_platform
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
......@@ -50,8 +51,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert zp_valid(gate_up_proj.input_zero_point)
assert zp_valid(down_proj.input_zero_point)
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
......@@ -75,6 +78,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
......@@ -129,17 +134,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
......@@ -152,20 +161,25 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.group_size == (-1 if group is None else group)
assert qkv_proj.scheme.group_size == (-1
if group is None else group)
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.scheme.pack_factor == pack_factor
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
......@@ -173,15 +187,19 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
assert qkv_proj.weight_packed.dtype is torch.int32
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
......@@ -189,12 +207,14 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
def test_compressed_tensors_fp8(vllm_runner):
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
......@@ -207,6 +227,8 @@ def test_compressed_tensors_fp8(vllm_runner):
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
......@@ -248,13 +270,16 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
......@@ -273,13 +298,16 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
......@@ -293,11 +321,13 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
......@@ -308,6 +338,8 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
......@@ -49,14 +49,18 @@ KV_CACHE_MODELS = [
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
attn = model.model.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
# these checkpoints have scales < 1.0
# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
llm.apply_model(check_model)
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
......@@ -77,7 +81,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
if kv_cache_dtype == "fp8":
......@@ -94,6 +98,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
llm.apply_model(check_model)
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
......
......@@ -28,20 +28,23 @@ def test_lm_head(
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)
with vllm_runner(model, dtype=torch.float16,
max_model_len=2048) as vllm_model:
def check_model(model):
lm_head_layer = model.lm_head
if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
assert isinstance(lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod,
MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method,
UnquantizedEmbeddingMethod)
vllm_model.apply_model(check_model)
print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
......@@ -12,7 +12,8 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
def test_quark_fp8(vllm_runner):
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
......@@ -26,5 +27,7 @@ def test_quark_fp8(vllm_runner):
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
......@@ -3,6 +3,7 @@ import json
import os
import pathlib
import subprocess
from functools import partial
from unittest.mock import MagicMock, patch
import openai
......@@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
# yapf: enable
from vllm.utils import PlaceholderModule, import_from_path
from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip
......@@ -58,16 +58,6 @@ def is_curl_installed():
return False
def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \
.model \
.llm_engine \
.model_executor \
.driver_worker \
.model_runner \
.model
def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random()
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
......@@ -121,8 +111,10 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)
vllm_model.apply_model(
partial(serialize_vllm_model,
tensorizer_config=config_for_serializing))
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
......@@ -175,8 +167,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
vllm_model.apply_model(
partial(
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
with vllm_runner(
model_ref,
......@@ -215,8 +209,10 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
vllm_model.apply_model(
partial(
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
model_loader_extra_config = {
"tensorizer_uri": str(model_path),
......@@ -337,7 +333,9 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(get_torch_model(vllm_model), config)
vllm_model.apply_model(
partial(serialize_vllm_model, tensorizer_config=config))
assert is_vllm_tensorized(config)
......
......@@ -5,10 +5,10 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union, cast, overload
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar, deprecated
......@@ -1818,17 +1818,6 @@ class LLMEngine:
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
......
......@@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import deprecated
from typing_extensions import TypeVar, deprecated
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
......@@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
......@@ -464,25 +467,42 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: Union[str, Callable],
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
Run a method on all workers, with homogeneous arguments.
The main extension point for the LLM entrypoint.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
executor = self.llm_engine.model_executor
return executor.apply_model(func)
def beam_search(
self,
......
......@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
import torch.nn as nn
from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -11,9 +14,12 @@ from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class ExecutorBase(ABC):
"""Base class for all executors.
......@@ -44,22 +50,37 @@ class ExecutorBase(ABC):
@abstractmethod
def _init_executor(self) -> None:
pass
raise NotImplementedError
@abstractmethod
def collective_rpc(self,
method: Union[str, Callable],
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
The main interface of the executor to run a method on all workers,
with homogeneous arguments.
If the args are heterogeneous, then we can pack them into a list,
and unpack them in the method of every worker, because every worker
knows their own rank.
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
pass
raise NotImplementedError
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
......@@ -97,6 +118,17 @@ class ExecutorBase(ABC):
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
def rpc_func(worker: WorkerBase) -> _R:
return func(worker.get_model())
return self.collective_rpc(rpc_func)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
......
......@@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
) -> List[Any]:
"""Runs the given method on all workers.
Args:
......
......@@ -459,16 +459,7 @@ def tensorize_vllm_model(engine_args: EngineArgs,
stream.write(encryption_params.key)
engine = LLMEngine.from_engine_args(engine_args)
if tensorizer_config._is_sharded:
# if the engine is a distributed engine (for tensor parallel) then each
# worker shard needs to serialize its part of the model.
engine.model_executor._run_workers(
engine.model_executor.collective_rpc(
"save_tensorized_model",
tensorizer_config=tensorizer_config,
)
else:
# with a single worker, we can get to the underlying model directly
serialize_vllm_model(
engine.model_executor.driver_worker.model_runner.model,
tensorizer_config,
kwargs=dict(tensorizer_config=tensorizer_config),
)
......@@ -2,6 +2,7 @@ import weakref
from typing import List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
......@@ -10,6 +11,10 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
class _DummyModel(nn.Module):
pass
class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
......@@ -36,7 +41,6 @@ class NGramWorker(NonLLMProposerWorkerBase):
def init_device(self):
self.device = torch.device(f"{self.device_type}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current NGramWorker only supports Top1Proposer
self._proposer = Top1Proposer(
......@@ -45,6 +49,12 @@ class NGramWorker(NonLLMProposerWorkerBase):
vocab_size=self.vocab_size,
)
def load_model(self) -> None:
pass # Dummy
def get_model(self) -> nn.Module:
return _DummyModel()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment