Unverified Commit 611722d1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

test: several fixes for e2e vllm tests (#1633)

parent 2becce56
...@@ -32,14 +32,13 @@ logger = logging.getLogger(__name__) ...@@ -32,14 +32,13 @@ logger = logging.getLogger(__name__)
class LocalConnector(PlannerConnector): class LocalConnector(PlannerConnector):
def __init__(self, namespace: str, runtime: DistributedRuntime, backend: str): def __init__(self, namespace: str, runtime: DistributedRuntime):
""" """
Initialize LocalConnector and connect to CircusController. Initialize LocalConnector and connect to CircusController.
Args: Args:
namespace: The Dynamo namespace namespace: The Dynamo namespace
runtime: Optional DistributedRuntime instance runtime: Optional DistributedRuntime instance
backend: The backend to use ("vllm_v0", "vllm_v1")
""" """
self.namespace = namespace self.namespace = namespace
self.runtime = runtime self.runtime = runtime
......
...@@ -12,3 +12,74 @@ ...@@ -12,3 +12,74 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os
import pytest
# List of models used in the serve tests
SERVE_TEST_MODELS = [
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llava-hf/llava-1.5-7b-hf",
]
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session")
def predownload_models():
# Check for HF_TOKEN in environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
logger.info("HF_TOKEN found in environment")
else:
logger.warning(
"HF_TOKEN not found in environment. "
"Some models may fail to download or you may encounter rate limits. "
"Get a token from https://huggingface.co/settings/tokens"
)
try:
from huggingface_hub import snapshot_download
for model_id in SERVE_TEST_MODELS:
logger.info(f"Pre-downloading model: {model_id}")
try:
# Download the full model snapshot (includes all files)
# HuggingFace will handle caching automatically
snapshot_download(
repo_id=model_id,
token=hf_token,
)
logger.info(f"Successfully pre-downloaded: {model_id}")
except Exception as e:
logger.error(f"Failed to pre-download {model_id}: {e}")
# Don't fail the fixture - let individual tests handle missing models
except ImportError:
logger.warning(
"huggingface_hub not installed. "
"Models will be downloaded during test execution."
)
yield
# Automatically use the predownload fixture for all serve tests
def pytest_collection_modifyitems(config, items):
for item in items:
# Skip items that don't have fixturenames (like MypyFileItem)
if not hasattr(item, "fixturenames"):
continue
# Only apply to tests in the serve directory
if "serve" in str(item.path):
# Check if the test already uses the fixture
if "predownload_models" not in item.fixturenames:
# Don't add if test explicitly marks to skip model download
if not item.get_closest_marker("skip_model_download"):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models")
...@@ -159,7 +159,7 @@ deployment_graphs = { ...@@ -159,7 +159,7 @@ deployment_graphs = {
"multimodal_agg": ( "multimodal_agg": (
DeploymentGraph( DeploymentGraph(
module="graphs.agg:Frontend", module="graphs.agg:Frontend",
config="configs/agg.yaml", config="configs/agg-llava.yaml",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
endpoints=["v1/chat/completions"], endpoints=["v1/chat/completions"],
response_handlers=[ response_handlers=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment