Unverified commit d1f7392c, authored by Micah Williamson, committed by GitHub
Browse files

[ROCm][CI] Fix v1/logits_processors failure on ROCm (#29927)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent 9ae3c55b
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random import random
import sys
from typing import Any from typing import Any
import pytest import pytest
...@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test ...@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from tests.v1.logits_processors.utils import ( from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN, DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MAX_TOKENS,
MODEL_NAME, MODEL_NAME,
POOLING_MODEL_NAME, POOLING_MODEL_NAME,
...@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import ( ...@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource, CustomLogitprocSource,
DummyLogitsProcessor, DummyLogitsProcessor,
WrappedPerReqLogitsProcessor, WrappedPerReqLogitsProcessor,
dummy_module,
prompts, prompts,
) )
from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import entry_points as fake_entry_points
...@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource ...@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {} kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN) # Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object # Scenario: load logitproc from provided class object
......
...@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te ...@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from tests.v1.logits_processors.utils import ( from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN, DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MAX_TOKENS,
MODEL_NAME, MODEL_NAME,
TEMP_GREEDY, TEMP_GREEDY,
dummy_module,
prompts, prompts,
) )
from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import entry_points as fake_entry_points
...@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint( ...@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main.main() main.main()
def _server_with_logitproc_module( def _server_with_logitproc_fqcn(
env_dict: dict[str, str] | None, env_dict: dict[str, str] | None,
model: str, model: str,
vllm_serve_args: list[str], vllm_serve_args: list[str],
) -> None: ) -> None:
"""Start vLLM server, inject module with dummy logitproc""" """Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
# fork is required for workers to see entrypoint patch
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None: if env_dict is not None:
os.environ.update(env_dict) os.environ.update(env_dict)
...@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch): ...@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if request.param: if request.param:
# Launch server, append FQCN argument, inject dummy logitproc module # Launch server, append FQCN argument, inject dummy logitproc module
args = default_server_args + request.param args = default_server_args + request.param
_server_fxn = _server_with_logitproc_module _server_fxn = _server_with_logitproc_fqcn
else: else:
# Launch server, inject dummy logitproc entrypoint # Launch server, inject dummy logitproc entrypoint
args = default_server_args args = default_server_args
......
...@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token" ...@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0 TEMP_GREEDY = 0.0
MAX_TOKENS = 20 MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule" DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment