Unverified Commit 92bbbc39 authored by Graham King, committed by GitHub

fix: Fix vllm/sglang engine model name if using HF repo (#986)


Signed-off-by: Graham King <graham@gkgk.org>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 0a894cc3
@@ -190,6 +190,8 @@ is equivalent to
 dynamo-run in=text out=mistralrs Qwen/Qwen2.5-3B-Instruct
 ```
+If you have multiple GPUs, mistral.rs does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.
 ### llamacpp
 Currently [llama.cpp](https://github.com/ggml-org/llama.cpp) is not included by default. Build it like this:
@@ -202,6 +204,8 @@ cargo build --features llamacpp[,cuda|metal|vulkan] -p dynamo-run
 dynamo-run out=llamacpp ~/llms/Llama-3.2-3B-Instruct-Q6_K.gguf
 ```
+llamacpp is best for single-GPU inference with a quantized GGUF model file.
 ### sglang
 The [SGLang](https://docs.sglang.ai/index.html) engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.
@@ -416,7 +420,8 @@ async def worker(runtime: DistributedRuntime):
     model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # or "/data/models/Qwen2.5-0.5B-Instruct"
     model_type = ModelType.Backend
     endpoint = component.endpoint("endpoint")
-    await register_llm(endpoint, model_path, model_type)
+    # Optional last param to register_llm is model_name. If not present, it is derived from model_path.
+    await register_llm(model_type, endpoint, model_path)
     # Initialize your engine here
     # engine = ...
...
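For reference, a minimal sketch of the reworked `register_llm` call with the optional `model_name` supplied explicitly. The import path is assumed (it is not shown in this diff), `endpoint` is obtained as in the docs example above, and the local path is illustrative:

```python
from dynamo.llm import ModelType, register_llm  # import path assumed

# Serve a local checkout under its HF repo name rather than the folder name.
model_path = "/data/models/Qwen2.5-0.5B-Instruct"  # illustrative local path
await register_llm(
    ModelType.Backend,
    endpoint,  # from component.endpoint("endpoint") as in the docs example
    model_path,
    "Qwen/Qwen2.5-0.5B-Instruct",  # optional model_name override
)
```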
@@ -72,7 +72,7 @@ pub async fn run(
         LocalModel::prepare(
             model_path.to_str().context("Invalid UTF-8 in model path")?,
             flags.model_config.as_deref(),
-            maybe_model_name.as_deref(),
+            maybe_model_name,
         )
         .await?
     }
@@ -136,7 +136,7 @@ pub async fn run(
     };
     let (py_script, child) = match subprocess::start(
         subprocess::sglang::PY,
-        local_model.path(),
+        &local_model,
         flags.tensor_parallel_size,
         if flags.base_gpu_id == 0 {
             None
@@ -172,7 +172,7 @@ pub async fn run(
     }
     let (py_script, child) = match subprocess::start(
         subprocess::vllm::PY,
-        local_model.path(),
+        &local_model,
         flags.tensor_parallel_size,
         None, // base_gpu_id. vllm uses CUDA_VISIBLE_DEVICES instead
         None, // multi-node config. vllm uses `ray`, see guide
...
@@ -12,6 +12,7 @@ use regex::Regex;
 use tokio::io::AsyncBufReadExt;
 use dynamo_llm::engines::MultiNodeConfig;
+use dynamo_llm::LocalModel;
 pub mod sglang;
 pub mod vllm;
@@ -22,8 +23,8 @@ pub const ENDPOINT: &str = "dyn://dynamo.internal.worker";
 pub async fn start(
     // The Python code to run
     py_script: &'static str,
-    // Path to folder or file with model weights
-    model_path: &Path,
+    // Model info
+    local_model: &LocalModel,
     // How many GPUs to use
     tensor_parallel_size: u32,
     // sglang: which GPU to start from on a multi-GPU system
@@ -43,8 +44,10 @@ pub async fn start(
         script_path.to_string_lossy().to_string(),
         "--endpoint".to_string(),
         ENDPOINT.to_string(),
-        "--model".to_string(),
-        model_path.to_string_lossy().to_string(),
+        "--model-path".to_string(),
+        local_model.path().to_string_lossy().to_string(),
+        "--model-name".to_string(),
+        local_model.display_name().to_string(),
         "--tensor-parallel-size".to_string(),
         tensor_parallel_size.to_string(),
     ];
...
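Putting the new flags together, the launcher now hands the Python worker roughly the following argument list. This is a sketch only: the interpreter and script filename are assumed, the model values are illustrative, and the endpoint is the `ENDPOINT` constant shown above:

```python
# Roughly the argv assembled by the Rust launcher above (values illustrative).
cmd = [
    "python3",
    "vllm_inc.py",  # or the sglang worker script
    "--endpoint", "dyn://dynamo.internal.worker",
    "--model-path", "/data/models/Qwen2.5-0.5B-Instruct",
    "--model-name", "Qwen/Qwen2.5-0.5B-Instruct",
    "--tensor-parallel-size", "1",
]
```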
@@ -17,7 +17,9 @@
 import argparse
 import asyncio
+import logging
 import sys
+from typing import Optional
 import sglang
 import uvloop
@@ -29,6 +31,8 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
 DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
 DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+logging.basicConfig(level=logging.DEBUG)
 class Config:
     """Command line parameters or defaults"""
@@ -36,7 +40,8 @@ class Config:
     namespace: str
     component: str
     endpoint: str
-    model: str
+    model_path: str
+    model_name: Optional[str]
     base_gpu_id: int
     tensor_parallel_size: int
     nnodes: int
@@ -54,12 +59,17 @@ class RequestHandler:
         self.engine_client = engine

     async def generate(self, request):
-        # print(f"Received request: {request}")
-        sampling_params = {
-            "temperature": request["sampling_options"]["temperature"],
-            # sglang defaults this to 128
-            "max_new_tokens": request["stop_conditions"]["max_tokens"],
-        }
+        sampling_params = {}
+        for key, value in request["sampling_options"].items():
+            if value:
+                # TODO: Do these always match? Maybe allow-list the fields that do match
+                sampling_params[key] = value
+        # sglang defaults this to 128
+        max_new_tokens = request["stop_conditions"]["max_tokens"]
+        if max_new_tokens:
+            sampling_params["max_new_tokens"] = max_new_tokens
         num_output_tokens_so_far = 0
         gen = await self.engine_client.async_generate(
             input_ids=request["token_ids"], sampling_params=sampling_params, stream=True
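The TODO in the new sglang handler asks whether Dynamo's `sampling_options` field names always match sglang's. One way to make that explicit is the allow-list the TODO suggests; a sketch of just the copying step, with an assumed (not exhaustive) set of overlapping keys:

```python
# Hypothetical allow-list of sampling fields shared by Dynamo and sglang.
SGLANG_SAMPLING_KEYS = {"temperature", "top_p", "top_k"}  # assumed overlap

sampling_params = {
    key: value
    for key, value in request["sampling_options"].items()
    if key in SGLANG_SAMPLING_KEYS and value is not None
}
```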
@@ -91,12 +101,12 @@ async def init(runtime: DistributedRuntime, config: Config):
     await component.create_service()
     endpoint = component.endpoint(config.endpoint)
-    print("Started server instance")
-
-    await register_llm(endpoint, config.model, ModelType.Backend)
+    await register_llm(
+        ModelType.Backend, endpoint, config.model_path, config.model_name
+    )
     arg_map = {
-        "model_path": config.model,
+        "model_path": config.model_path,
         "skip_tokenizer_init": True,
         "tp_size": config.tensor_parallel_size,
         "base_gpu_id": config.base_gpu_id,
@@ -121,6 +131,8 @@ async def init(runtime: DistributedRuntime, config: Config):
         logging.debug(f"Adding extra engine arguments: {json_map}")
         arg_map = {**arg_map, **json_map}  # json_map gets precedence
+    # TODO fetch default SamplingParams from generation_config.json
     engine_args = ServerArgs(**arg_map)
     engine_client = sglang.Engine(server_args=engine_args)
@@ -140,11 +152,17 @@ def cmd_line_args():
         help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
     )
     parser.add_argument(
-        "--model",
+        "--model-path",
         type=str,
         default=DEFAULT_MODEL,
         help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
     )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="",
+        help="Name to serve the model under. Defaults to deriving it from model path.",
+    )
     parser.add_argument(
         "--base-gpu-id",
         type=int,
@@ -178,12 +196,17 @@ def cmd_line_args():
     args = parser.parse_args()
     config = Config()
-    config.model = args.model
+    config.model_path = args.model_path
+    if args.model_name:
+        config.model_name = args.model_name
+    else:
+        # This becomes an `Option` on the Rust side
+        config.model_name = None
     endpoint_str = args.endpoint.replace("dyn://", "", 1)
     endpoint_parts = endpoint_str.split(".")
     if len(endpoint_parts) != 3:
-        print(
+        logging.error(
             f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
         )
         sys.exit(1)
...
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
 # A very basic example of vllm worker handling pre-processed requests.
 #
@@ -24,13 +11,16 @@
 # Start nats and etcd:
 # - nats-server -js
 #
-# Window 1: `python server_vllm.py`. Wait for log "Starting endpoint".
+# Window 1: `python vllm_inc.py`. Wait for log "Starting endpoint".
 # Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
 import argparse
 import asyncio
 import logging
+import os
 import sys
+import uuid
+from typing import Optional
 import uvloop
 from vllm import SamplingParams
@@ -46,8 +36,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
 DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
 DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
-# TODO this should match DYN_LOG level
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 class Config:
@@ -56,7 +45,8 @@ class Config:
     namespace: str
     component: str
     endpoint: str
-    model: str
+    model_path: str
+    model_name: Optional[str]
     tensor_parallel_size: int
     extra_engine_args: str
@@ -71,14 +61,21 @@ class RequestHandler:
         self.default_sampling_params = default_sampling_params

     async def generate(self, request):
-        request_id = "1"  # hello_world example only
-        logging.debug(f"Received request: {request}")
+        # logging.debug(f"Received request: {request}")
+        request_id = str(uuid.uuid4().hex)
         prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
         sampling_params = SamplingParams(**self.default_sampling_params)
-        sampling_params.temperature = request["sampling_options"]["temperature"]
-        sampling_params.max_tokens = request["stop_conditions"]["max_tokens"]
+        for key, value in request["sampling_options"].items():
+            if not value:
+                continue
+            if hasattr(sampling_params, key):
+                setattr(sampling_params, key, value)
+        max_tokens = request["stop_conditions"]["max_tokens"]
+        if max_tokens:
+            sampling_params.max_tokens = max_tokens
         num_output_tokens_so_far = 0
         gen = self.engine_client.generate(prompt, sampling_params, request_id)
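One subtlety in the new vllm loop: `if not value` skips every falsy value, so an explicit `temperature` of `0.0` (greedy decoding) would be silently ignored. A sketch of a stricter variant that only skips unset options:

```python
# Variant of the loop above that keeps legitimate zeros (e.g. temperature=0.0).
for key, value in request["sampling_options"].items():
    if value is None:  # skip only unset options, not all falsy ones
        continue
    if hasattr(sampling_params, key):
        setattr(sampling_params, key, value)
```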
@@ -119,15 +116,18 @@ async def init(runtime: DistributedRuntime, config: Config):
     await component.create_service()
     endpoint = component.endpoint(config.endpoint)
-    logging.info("Started server instance")
-
-    await register_llm(endpoint, config.model, ModelType.Backend)
+    await register_llm(
+        ModelType.Backend, endpoint, config.model_path, config.model_name
+    )
     arg_map = {
-        "model": config.model,
+        "model": config.model_path,
         "task": "generate",
         "tensor_parallel_size": config.tensor_parallel_size,
         "skip_tokenizer_init": True,
+        "disable_log_requests": True,
+        # KV routing relies on logging KV metrics
+        "disable_log_stats": False,
     }
     if config.extra_engine_args != "":
         json_map = {}
@@ -142,6 +142,7 @@ async def init(runtime: DistributedRuntime, config: Config):
         logging.debug(f"Adding extra engine arguments: {json_map}")
         arg_map = {**arg_map, **json_map}  # json_map gets precedence
+    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
     engine_args = AsyncEngineArgs(**arg_map)
     model_config = engine_args.create_model_config()
     # Load default sampling params from `generation_config.json`
@@ -168,11 +169,17 @@ def cmd_line_args():
         help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
     )
     parser.add_argument(
-        "--model",
+        "--model-path",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
     )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="",
+        help="Name to serve the model under. Defaults to deriving it from model path.",
+    )
     parser.add_argument(
         "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
     )
@@ -185,7 +192,12 @@ def cmd_line_args():
     args = parser.parse_args()
     config = Config()
-    config.model = args.model
+    config.model_path = args.model_path
+    if args.model_name:
+        config.model_name = args.model_name
+    else:
+        # This becomes an `Option` on the Rust side
+        config.model_name = None
     endpoint_str = args.endpoint.replace("dyn://", "", 1)
     endpoint_parts = endpoint_str.split(".")
...
@@ -97,9 +97,7 @@ async def init(runtime: DistributedRuntime, config: Config):
     await component.create_service()
     endpoint = component.endpoint(config.endpoint)
-    print("Started server instance")
-
-    await register_llm(endpoint, config.model, ModelType.Backend)
+    await register_llm(ModelType.Backend, endpoint, config.model)
     engine_args = ServerArgs(
         model_path=config.model,
...
@@ -112,9 +112,7 @@ async def init(runtime: DistributedRuntime, config: Config):
     await component.create_service()
     endpoint = component.endpoint(config.endpoint)
-    print("Started server instance")
-
-    await register_llm(endpoint, config.model, ModelType.Backend)
+    await register_llm(ModelType.Backend, endpoint, config.model)
     engine_args = AsyncEngineArgs(
         model=config.model,
...
@@ -101,12 +101,13 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
 }

 #[pyfunction]
-#[pyo3(text_signature = "(endpoint, path, model_type)")]
+#[pyo3(signature = (model_type, endpoint, model_path, model_name=None))]
 fn register_llm<'p>(
     py: Python<'p>,
-    endpoint: Endpoint,
-    path: &str,
     model_type: ModelType,
+    endpoint: Endpoint,
+    model_path: &str,
+    model_name: Option<&str>,
 ) -> PyResult<Bound<'p, PyAny>> {
     let model_type_obj = match model_type {
         ModelType::Chat => llm_rs::model_type::ModelType::Chat,
@@ -114,10 +115,11 @@ fn register_llm<'p>(
         ModelType::Backend => llm_rs::model_type::ModelType::Backend,
     };
-    let inner_path = path.to_string();
+    let inner_path = model_path.to_string();
+    let model_name = model_name.map(|n| n.to_string());
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
         // Download from HF, load the ModelDeploymentCard
-        let mut local_model = llm_rs::LocalModel::prepare(&inner_path, None, None)
+        let mut local_model = llm_rs::LocalModel::prepare(&inner_path, None, model_name)
             .await
             .map_err(to_pyerr)?;
...
@@ -610,7 +610,7 @@ class ModelType:
     """What type of request this model needs: Chat, Component or Backend (pre-processed)"""
     ...

-async def register_llm(endpoint: Endpoint, path: str, model_type: ModelType) -> None:
+async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str]) -> None:
     """Attach the model at path to the given endpoint, and advertise it as model_type"""
     ...
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
 use std::fs;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
-use anyhow::Context;
 use dynamo_runtime::component::Endpoint;
 use dynamo_runtime::traits::DistributedRuntimeProvider;
@@ -57,6 +44,10 @@ impl LocalModel {
         &self.full_path
     }
+    pub fn display_name(&self) -> &str {
+        &self.card.display_name
+    }
     pub fn service_name(&self) -> &str {
         &self.card.service_name
     }
@@ -74,7 +65,7 @@ impl LocalModel {
     pub async fn prepare(
         model_path: &str,
         override_config: Option<&Path>,
-        override_name: Option<&str>,
+        override_name: Option<String>,
     ) -> anyhow::Result<LocalModel> {
         // Name it
@@ -91,23 +82,21 @@ impl LocalModel {
             fs::canonicalize(relative_path)?
         };
-        let model_name = match override_name.map(|s| s.to_string()) {
-            Some(name) => name,
-            None => {
-                if is_hf_repo {
-                    // HF repos use their full name ("org/name") not the folder name
-                    relative_path.to_string()
-                } else {
-                    full_path
-                        .iter()
-                        .next_back()
-                        .map(|n| n.to_string_lossy().into_owned())
-                        .with_context(|| {
-                            format!("Invalid model path, too short: {}", full_path.display())
-                        })?
-                }
-            }
-        };
+        let model_name = override_name.unwrap_or_else(|| {
+            if is_hf_repo {
+                // HF repos use their full name ("org/name") not the folder name
+                relative_path.to_string()
+            } else {
+                full_path
+                    .iter()
+                    .next_back()
+                    .map(|n| n.to_string_lossy().into_owned())
+                    .unwrap_or_else(|| {
+                        // Panic because we can't do anything without a model
+                        panic!("Invalid model path, too short: '{}'", full_path.display())
+                    })
+            }
+        });
         // Load the ModelDeploymentCard
...
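The naming rule after this change: an explicit `model_name` wins; otherwise an HF repo identifier keeps its full `org/name`, and a local path falls back to its last component. A Python mirror of that fallback, for illustration only (how `is_hf_repo` is detected is handled elsewhere in the Rust code):

```python
from pathlib import Path
from typing import Optional

def derive_model_name(model_path: str, is_hf_repo: bool, override: Optional[str] = None) -> str:
    """Illustrative mirror of LocalModel::prepare's naming fallback."""
    if override is not None:
        return override
    if is_hf_repo:
        # HF repos keep their full name, e.g. "Qwen/Qwen2.5-0.5B-Instruct"
        return model_path
    name = Path(model_path).name  # last path component, e.g. "Qwen2.5-0.5B-Instruct"
    if not name:
        raise ValueError(f"Invalid model path, too short: '{model_path}'")
    return name
```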