Unverified commit 765e4610, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Fix Nemotron Parse loading (#37407)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6a9cceb2
...@@ -24,12 +24,8 @@ class ModelRequestData(NamedTuple): ...@@ -24,12 +24,8 @@ class ModelRequestData(NamedTuple):
sampling_params: SamplingParams | None = None sampling_params: SamplingParams | None = None
@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION]) @pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl( def test_keye_vl(image_assets, question: str):
image_assets,
question: str,
):
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
image_urls = [encode_image_url(image) for image in images] image_urls = [encode_image_url(image) for image in images]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
import pytest import pytest
import regex as re
from transformers import AutoModel from transformers import AutoModel
from tests.models.utils import check_logprobs_close from tests.models.utils import check_logprobs_close
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.logprobs import Logprob, SampleLogprobs
from vllm.tokenizers import TokenizerLike
from ....conftest import HfRunner, PromptImageInput, VllmRunner from ....conftest import HfRunner, PromptImageInput, VllmRunner
from ....utils import create_new_process_for_each_test
IMAGE = ImageAsset("paper-11").pil_image_ext(ext="png").convert("RGB") IMAGE = ImageAsset("paper-11").pil_image_ext(ext="png").convert("RGB")
PROMPT = "</s><s><predict_bbox><predict_classes><output_markdown>" PROMPT = "</s><s><predict_bbox><predict_classes><output_markdown>"
class DummyLogprobs(dict[int, Logprob]):
    """Placeholder logprob table mapping each given token id to ``Logprob(0.0)``."""

    def __init__(self, vocab_ids: Iterable[int]):
        # A single Logprob instance is shared by every key, exactly as
        # dict.fromkeys would do with a single default value.
        placeholder = Logprob(0.0)
        super().__init__((token_id, placeholder) for token_id in vocab_ids)

    def __repr__(self):
        # Keep the repr short: the table holds one entry per token id passed
        # in, so the default dict repr would be unreadable in test output.
        return "DummyLogprobs()"
def mask_bbox_tokens(
    output: tuple[list[int], str, SampleLogprobs],
    tokenizer: TokenizerLike,
) -> tuple[list[int], str, SampleLogprobs]:
    """
    Always pass check_logprobs_close check for bounding box tokens
    because it is reasonable for them to differ slightly.

    Args:
        output: A ``(token_ids, text, logprobs)`` triple for one generation.
        tokenizer: Tokenizer used to decode each token id and to enumerate
            the vocabulary ids for the placeholder table.

    Returns:
        The same ``(token_ids, text, logprobs)`` triple, except that the
        logprobs entry of every token whose decoded text starts with a
        ``<x_...>`` / ``<y_...>`` coordinate tag is replaced by a
        ``DummyLogprobs`` covering the whole vocabulary.
    """
    # Compile once instead of resolving the pattern on every iteration.
    bbox_pattern = re.compile(r"<[xy]_[\d.]+>")
    output_ids, output_str, out_logprobs = output

    # The placeholder spans the full vocabulary, so build it at most once
    # (and only if a bounding box token actually shows up) and reuse the same
    # read-only table for every masked position.
    dummy_logprobs: DummyLogprobs | None = None

    masked_logprobs: list[dict[int, Logprob]] = []
    for token, logprobs in zip(output_ids, out_logprobs):
        if bbox_pattern.match(tokenizer.decode(token)):
            if dummy_logprobs is None:
                dummy_logprobs = DummyLogprobs(tokenizer.get_vocab().values())
            masked_logprobs.append(dummy_logprobs)
        else:
            masked_logprobs.append(logprobs)

    return output_ids, output_str, masked_logprobs
def run_test( def run_test(
hf_runner: type[HfRunner], hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner], vllm_runner: type[VllmRunner],
...@@ -44,6 +76,8 @@ def run_test( ...@@ -44,6 +76,8 @@ def run_test(
for prompts, images in inputs for prompts, images in inputs
] ]
tokenizer = vllm_model.llm.get_tokenizer()
with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_outputs_per_case = [ hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit( hf_model.generate_greedy_logprobs_limit(
...@@ -58,18 +92,20 @@ def run_test( ...@@ -58,18 +92,20 @@ def run_test(
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=[
outputs_1_lst=vllm_outputs, mask_bbox_tokens(output, tokenizer) for output in hf_outputs
],
outputs_1_lst=[
mask_bbox_tokens(output, tokenizer) for output in vllm_outputs
],
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
) )
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["nvidia/NVIDIA-Nemotron-Parse-v1.1"]) @pytest.mark.parametrize("model", ["nvidia/NVIDIA-Nemotron-Parse-v1.1"])
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@create_new_process_for_each_test("spawn")
def test_models( def test_models(
hf_runner, vllm_runner, model: str, dtype: str, num_logprobs: int hf_runner, vllm_runner, model: str, dtype: str, num_logprobs: int
) -> None: ) -> None:
...@@ -77,10 +113,7 @@ def test_models( ...@@ -77,10 +113,7 @@ def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
inputs=[ inputs=[
( ([PROMPT] * 10, [IMAGE] * 10),
[PROMPT] * 10,
[IMAGE] * 10,
),
], ],
model=model, model=model,
dtype=dtype, dtype=dtype,
......
...@@ -8,7 +8,7 @@ from tests.conftest import VllmRunner ...@@ -8,7 +8,7 @@ from tests.conftest import VllmRunner
from tests.utils import create_new_process_for_each_test from tests.utils import create_new_process_for_each_test
@create_new_process_for_each_test() # Memory is not cleaned up properly otherwise @create_new_process_for_each_test() # Hangs otherwise
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
......
...@@ -319,8 +319,9 @@ class MBartDecoderNoPos(nn.Module): ...@@ -319,8 +319,9 @@ class MBartDecoderNoPos(nn.Module):
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"), (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"), (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"), # MergedColumnParallelLinear uses integer indices (0, 1)
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment