Commit 82e40fb7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-ori

parents 30a1922e 58996f35
......@@ -566,6 +566,42 @@ def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData:
)
# GLM-OCR
def run_glm_ocr(questions: list[str], modality: str) -> ModelRequestData:
model_name = "zai-org/GLM-OCR"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={
"size": {"shortest_edge": 12544, "longest_edge": 47040000},
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
if modality == "image":
placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
elif modality == "video":
placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
prompts = [
(
"[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
f"{placeholder}"
f"{question}<|assistant|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1889,6 +1925,32 @@ def run_step3(questions: list[str], modality: str) -> ModelRequestData:
)
# StepVL10B
def run_step_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "stepfun-ai/Step3-VL-10B"
engine_args = EngineArgs(
model=model_name,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
trust_remote_code=True,
limit_mm_per_prompt={modality: 1},
reasoning_parser="deepseek_r1",
)
prompts = [
"<|begin▁of▁sentence|> You are a helpful assistant.<|BOT|>user\n "
f"<im_patch>{question} <|EOT|><|BOT|>assistant\n<think>\n"
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# omni-research/Tarsier-7b
def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1962,6 +2024,7 @@ model_example_map = {
"glm4_1v": run_glm4_1v,
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"glm_ocr": run_glm_ocr,
"h2ovl_chat": run_h2ovl,
"hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
......@@ -2006,6 +2069,7 @@ model_example_map = {
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
"step3": run_step3,
"stepvl": run_step_vl,
"tarsier": run_tarsier,
"tarsier2": run_tarsier2,
}
......@@ -2013,6 +2077,7 @@ model_example_map = {
MODELS_NEED_VIDEO_METADATA = [
"glm4_1v",
"glm_ocr",
"glm4_5v",
"glm4_5v_fp8",
"molmo2",
......
......@@ -1182,6 +1182,32 @@ def load_step3(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_step_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "stepfun-ai/Step3-VL-10B"
engine_args = EngineArgs(
model=model_name,
max_num_batched_tokens=4096,
limit_mm_per_prompt={"image": len(image_urls)},
hf_overrides={"vision_config": {"enable_patch": False}},
trust_remote_code=True,
reasoning_parser="deepseek_r1",
)
prompt = (
"<|begin▁of▁sentence|> You are a helpful assistant.<|BOT|>user\n "
f"{'<im_patch>' * len(image_urls)}{question}<|EOT|><|BOT|>"
"assistant\n<think>\n"
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_tarsier(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "omni-research/Tarsier-7b"
......@@ -1374,6 +1400,7 @@ model_example_map = {
"rvl": load_r_vl,
"smolvlm": load_smolvlm,
"step3": load_step3,
"stepvl": load_step_vl,
"tarsier": load_tarsier,
"tarsier2": load_tarsier2,
"glm4_5v": load_glm4_5v,
......
......@@ -157,6 +157,37 @@ VLLM_CONFIGURE_LOGGING=0 \
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
```
### Example 4: Disable access logs for health check endpoints
In production environments, health check endpoints like `/health`, `/metrics`,
and `/ping` are frequently called by load balancers and monitoring systems,
generating a large volume of repetitive access logs. To reduce log noise while
keeping logs for other endpoints, use the `--disable-access-log-for-endpoints`
option.
**Disable access logs for health and metrics endpoints:**
```bash
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 \
--disable-access-log-for-endpoints /health,/metrics,/ping
```
**Common endpoints to consider filtering:**
| Endpoint | Description | Typical Caller |
| ---------- | ---------------------- | ---------------------------------------------------- |
| `/health` | Health check | Kubernetes liveness/readiness probes, load balancers |
| `/metrics` | Prometheus metrics | Prometheus scraper (every 15-60s) |
| `/ping` | SageMaker health check | SageMaker infrastructure |
| `/load` | Server load metrics | Custom monitoring |
**Notes:**
- This option only affects uvicorn access logs, not vLLM application logs
- Specify multiple endpoints by separating them with commas (no spaces)
- The filter uses exact path matching, query parameters are ignored (e.g., `/health?verbose=true` matches `/health`)
- If you need to completely disable all access logs, use `--disable-uvicorn-access-log` instead
## Additional resources
- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details)
......@@ -44,6 +44,7 @@ vllm = "vllm.entrypoints.cli.main:main"
[project.entry-points."vllm.general_plugins"]
lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver"
lora_hf_hub_resolver = "vllm.plugins.lora_resolvers.hf_hub_resolver:register_hf_hub_resolver"
[tool.setuptools_scm]
# no extra settings needed, presence enables setuptools-scm
......
......@@ -992,7 +992,7 @@ async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
# First turn - make a calculation
response1 = await client.responses.create(
model=model_name,
input="Calculate 123 * 456 using python and print the result.",
input="Calculate 1234 * 4567 using python tool and print the result.",
tools=tools,
temperature=0.0,
instructions=(
......
......@@ -42,6 +42,7 @@ class MockModelConfig:
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
hf_text_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
......
......@@ -518,6 +518,7 @@ class MockModelConfig:
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
hf_text_config = MockHFConfig()
logits_processors: list[str] | None = None
logits_processor_pattern = None
diff_sampling_param: dict | None = None
......
......@@ -22,6 +22,9 @@ from vllm.distributed import (
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -40,7 +43,6 @@ from .mk_objects import (
TestMoEQuantConfig,
expert_info,
make_fused_experts,
make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
......@@ -603,10 +605,12 @@ def make_modular_kernel(
routing_method=RoutingMethodType.DeepSeekV3,
)
# make modular kernel
prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
prepare_finalize = maybe_make_prepare_finalize(
moe=moe,
quant_config=quant_config,
allow_new_interface=True,
)
assert prepare_finalize is not None
fused_experts = make_fused_experts(
config.fused_experts_type,
......
......@@ -7,9 +7,6 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
......@@ -255,13 +252,12 @@ if has_pplx():
)
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
......@@ -429,24 +425,6 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
]
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: str | None,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
......
......@@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
......
......@@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
)
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
......
......@@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph(
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import numpy as np
import pytest
import torch
from transformers import AutoModelForTokenClassification
......@@ -8,6 +11,20 @@ from tests.models.utils import softmax
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def seed_everything():
"""Seed all random number generators for reproducibility."""
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
yield
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
......@@ -51,6 +68,7 @@ def test_bert_models(
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.flaky(reruns=3)
@torch.inference_mode
def test_modernbert_models(
hf_runner,
......@@ -59,6 +77,15 @@ def test_modernbert_models(
model: str,
dtype: str,
) -> None:
# NOTE: https://github.com/vllm-project/vllm/pull/32403
# `disham993/electrical-ner-ModernBERT-base` is a randomly initialized
# model, which can cause numerical precision variance and edge cases.
# We use @flaky(reruns=3) to mitigate intermittent failures.
print(
f"\n[NOTE] Testing {model} (randomly initialized weights) - "
"flaky tolerance enabled due to numerical precision variance."
)
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.token_classify(example_prompts)
......
......@@ -458,6 +458,20 @@ VLM_TEST_SETTINGS = {
],
marks=[large_gpu_mark(min_gb=32)],
),
"glm_ocr": VLMTestInfo(
models=["zai-org/GLM-OCR"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
max_model_len=2048,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models=[
"h2oai/h2ovl-mississippi-800m",
......
......@@ -91,6 +91,19 @@ MODEL_CONFIGS: dict[str, dict[str, Any]] = {
"use_processor": True,
"question": "What is the content of each image?",
},
"glm_ocr": {
"model_name": "zai-org/GLM-OCR",
"interface": "llm_generate",
"max_model_len": 131072,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "Text Recognition:",
},
"keye_vl": {
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
"interface": "llm_generate",
......
......@@ -122,6 +122,7 @@ MM_DATA_PATCHES = {
"ernie4_5_moe_vl": qwen3_vl_patch_mm_data,
"glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data,
"glm_ocr": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data,
"molmo2": qwen3_vl_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
......
......@@ -256,7 +256,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"ExaoneMoEForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.0.0"
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.1.0"
),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
......@@ -273,8 +273,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
"Glm4MoeLiteForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
is_available_online=False,
min_transformers_version="5.0.0",
),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
......@@ -651,7 +650,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
......@@ -694,7 +693,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512",
trust_remote_code=True,
min_transformers_version="5.0",
min_transformers_version="5.0.0",
),
"GraniteVision": _HfExamplesInfo("ibm-granite/granite-vision-3.3-2b"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
......@@ -707,6 +706,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
"GlmOcrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.1.0",
),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
......@@ -771,6 +775,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
)
},
),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
is_available_online=False,
),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025"
),
......@@ -1044,7 +1053,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"ExaoneMoeMTP": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B",
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.0.0",
min_transformers_version="5.1.0",
),
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
......@@ -1053,7 +1062,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0",
),
"GlmOcrMTPModel": _HfExamplesInfo(
"zai-org/GLM-OCR",
speculative_model="zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.1.0",
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
......@@ -1080,27 +1095,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0.dev"
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
),
"TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection",
min_transformers_version="5.0.0.dev",
min_transformers_version="5.0.0",
),
"TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True
),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0.dev"
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
),
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
),
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
......
......@@ -78,7 +78,7 @@ def test_models(
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
required = Version("5.0.0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip(
"MoE models with the Transformers modeling backend require "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import pytest
from huggingface_hub.constants import HF_HUB_CACHE
from vllm.plugins.lora_resolvers.hf_hub_resolver import HfHubResolver
LORA_LIB_MODEL_NAME = "ibm-granite/granite-3.3-8b-instruct"
# Repo with multiple LoRAs contained in it
LORA_LIB = "ibm-granite/granite-3.3-8b-rag-agent-lib"
LORA_NAME = "ibm-granite/granite-3.3-8b-rag-agent-lib/answerability_prediction_lora" # noqa: E501
NON_LORA_SUBPATH = "ibm-granite/granite-3.3-8b-rag-agent-lib/README.md"
LIB_DOWNLOAD_DIR = os.path.join(
HF_HUB_CACHE, "models--ibm-granite--granite-3.3-8b-rag-agent-lib"
)
INVALID_REPO_NAME = "thisrepodoesnotexist"
# Repo with only one LoRA in the root dir
LORA_REPO_MODEL_NAME = "meta-llama/Llama-2-7b-hf"
LORA_REPO = "yard1/llama-2-7b-sql-lora-test"
REPO_DOWNLOAD_DIR = os.path.join(
HF_HUB_CACHE, "models--yard1--llama-2-7b-sql-lora-test"
)
@pytest.mark.asyncio
async def test_hf_resolver_with_direct_path():
hf_resolver = HfHubResolver([LORA_REPO])
assert hf_resolver is not None
lora_request = await hf_resolver.resolve_lora(LORA_REPO_MODEL_NAME, LORA_REPO)
assert lora_request.lora_name == LORA_REPO
assert REPO_DOWNLOAD_DIR in lora_request.lora_path
assert "adapter_config.json" in os.listdir(lora_request.lora_path)
@pytest.mark.asyncio
async def test_hf_resolver_with_nested_paths():
hf_resolver = HfHubResolver([LORA_LIB])
assert hf_resolver is not None
lora_request = await hf_resolver.resolve_lora(LORA_LIB_MODEL_NAME, LORA_NAME)
assert lora_request is not None
assert lora_request.lora_name == LORA_NAME
assert LIB_DOWNLOAD_DIR in lora_request.lora_path
assert "adapter_config.json" in os.listdir(lora_request.lora_path)
@pytest.mark.asyncio
async def test_hf_resolver_with_multiple_repos():
hf_resolver = HfHubResolver([LORA_LIB, LORA_REPO])
assert hf_resolver is not None
lora_request = await hf_resolver.resolve_lora(LORA_LIB_MODEL_NAME, LORA_NAME)
assert lora_request is not None
assert lora_request.lora_name == LORA_NAME
assert LIB_DOWNLOAD_DIR in lora_request.lora_path
assert "adapter_config.json" in os.listdir(lora_request.lora_path)
@pytest.mark.asyncio
async def test_missing_adapter():
hf_resolver = HfHubResolver([LORA_LIB])
assert hf_resolver is not None
missing_lora_request = await hf_resolver.resolve_lora(LORA_LIB_MODEL_NAME, "foobar")
assert missing_lora_request is None
@pytest.mark.asyncio
async def test_nonlora_adapter():
hf_resolver = HfHubResolver([LORA_LIB])
assert hf_resolver is not None
readme_request = await hf_resolver.resolve_lora(
LORA_LIB_MODEL_NAME, NON_LORA_SUBPATH
)
assert readme_request is None
@pytest.mark.asyncio
async def test_invalid_repo():
hf_resolver = HfHubResolver([LORA_LIB])
assert hf_resolver is not None
invalid_repo_req = await hf_resolver.resolve_lora(
INVALID_REPO_NAME,
f"{INVALID_REPO_NAME}/foo",
)
assert invalid_repo_req is None
@pytest.mark.asyncio
async def test_trailing_slash():
hf_resolver = HfHubResolver([LORA_LIB])
assert hf_resolver is not None
lora_request = await hf_resolver.resolve_lora(
LORA_LIB_MODEL_NAME,
f"{LORA_NAME}/",
)
assert lora_request is not None
assert lora_request.lora_name == f"{LORA_NAME}/"
assert LIB_DOWNLOAD_DIR in lora_request.lora_path
assert "adapter_config.json" in os.listdir(lora_request.lora_path)
......@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment