Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
ab0da582
Unverified
Commit
ab0da582
authored
Oct 15, 2025
by
Graham King
Committed by
GitHub
Oct 15, 2025
Browse files
feat: Python binding to download a model. (#3593)
Signed-off-by:
Graham King
<
grahamk@nvidia.com
>
parent
6a1391eb
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
213 additions
and
121 deletions
+213
-121
components/src/dynamo/sglang/args.py
components/src/dynamo/sglang/args.py
+11
-1
components/src/dynamo/sglang/main.py
components/src/dynamo/sglang/main.py
+2
-1
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+22
-9
launch/dynamo-run/src/lib.rs
launch/dynamo-run/src/lib.rs
+25
-6
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+32
-3
lib/bindings/python/rust/llm/entrypoint.rs
lib/bindings/python/rust/llm/entrypoint.rs
+18
-2
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+7
-0
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+1
-0
lib/llm/src/local_model.rs
lib/llm/src/local_model.rs
+30
-40
lib/llm/src/model_card.rs
lib/llm/src/model_card.rs
+56
-53
tests/unit/test_sglang_unit.py
tests/unit/test_sglang_unit.py
+9
-6
No files found.
components/src/dynamo/sglang/args.py
View file @
ab0da582
...
...
@@ -16,6 +16,7 @@ from sglang.srt.server_args import ServerArgs
from
dynamo._core
import
get_reasoning_parser_names
,
get_tool_parser_names
from
dynamo.common.config_dump
import
register_encoder
from
dynamo.llm
import
fetch_llm
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang
import
__version__
...
...
@@ -203,8 +204,9 @@ def _set_parser(
return
dynamo_str
def
parse_args
(
args
:
list
[
str
])
->
Config
:
async
def
parse_args
(
args
:
list
[
str
])
->
Config
:
"""Parse CLI arguments and return combined configuration.
Download the model if necessary.
Args:
args: Command-line argument strings.
...
...
@@ -339,6 +341,14 @@ def parse_args(args: list[str]) -> Config:
)
logging
.
debug
(
f
"Dynamo args:
{
dynamo_args
}
"
)
# TODO: sglang downloads the model in `from_cli_args`, so we need to do it here.
# That's unfortunate because `parse_args` isn't the right place for this. Fix.
model_path
=
parsed_args
.
model_path
if
not
parsed_args
.
served_model_name
:
parsed_args
.
served_model_name
=
model_path
if
not
os
.
path
.
exists
(
model_path
):
parsed_args
.
model_path
=
await
fetch_llm
(
model_path
)
server_args
=
ServerArgs
.
from_cli_args
(
parsed_args
)
if
parsed_args
.
use_sglang_tokenizer
:
...
...
components/src/dynamo/sglang/main.py
View file @
ab0da582
...
...
@@ -45,8 +45,9 @@ async def worker(runtime: DistributedRuntime):
logging
.
info
(
"Signal handlers will trigger a graceful shutdown of the runtime"
)
config
=
parse_args
(
sys
.
argv
[
1
:])
config
=
await
parse_args
(
sys
.
argv
[
1
:])
dump_config
(
config
.
dynamo_args
.
dump_config_to
,
config
)
if
config
.
dynamo_args
.
embedding_worker
:
await
init_embedding
(
runtime
,
config
)
elif
config
.
dynamo_args
.
multimodal_processor
:
...
...
components/src/dynamo/vllm/main.py
View file @
ab0da582
...
...
@@ -21,6 +21,7 @@ from dynamo.llm import (
ModelType
,
ZmqKvEventPublisher
,
ZmqKvEventPublisherConfig
,
fetch_llm
,
register_llm
,
)
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
...
...
@@ -82,6 +83,15 @@ async def worker(runtime: DistributedRuntime):
logging
.
debug
(
"Signal handlers set up for graceful shutdown"
)
dump_config
(
config
.
dump_config_to
,
config
)
# Download the model if necessary.
# register_llm would do this for us, but we want it on disk before we start vllm.
# Ensure the original HF name (e.g. "Qwen/Qwen3-0.6B") is used as the served_model_name.
if
not
config
.
served_model_name
:
config
.
served_model_name
=
config
.
engine_args
.
served_model_name
=
config
.
model
if
not
os
.
path
.
exists
(
config
.
model
):
config
.
model
=
config
.
engine_args
.
model
=
await
fetch_llm
(
config
.
model
)
if
config
.
is_prefill_worker
:
await
init_prefill
(
runtime
,
config
)
logger
.
debug
(
"init_prefill completed"
)
...
...
@@ -165,9 +175,11 @@ def setup_vllm_engine(config, stat_logger=None):
disable_log_stats
=
engine_args
.
disable_log_stats
,
)
if
ENABLE_LMCACHE
:
logger
.
info
(
f
"VllmWorker for
{
config
.
model
}
has been initialized with LMCache"
)
logger
.
info
(
f
"VllmWorker for
{
config
.
served_model_name
}
has been initialized with LMCache"
)
else
:
logger
.
info
(
f
"VllmWorker for
{
config
.
model
}
has been initialized"
)
logger
.
info
(
f
"VllmWorker for
{
config
.
served_model_name
}
has been initialized"
)
return
engine_client
,
vllm_config
,
default_sampling_params
...
...
@@ -207,11 +219,13 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
True
,
metrics_labels
=
[(
"model"
,
config
.
model
)],
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)],
health_check_payload
=
health_check_payload
,
),
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
metrics_labels
=
[(
"model"
,
config
.
model
)]
handler
.
clear_kv_blocks
,
metrics_labels
=
[(
"model"
,
config
.
served_model_name
)],
),
)
logger
.
debug
(
"serve_endpoint completed for prefill worker"
)
...
...
@@ -251,7 +265,7 @@ async def init(runtime: DistributedRuntime, config: Config):
factory
=
StatLoggerFactory
(
component
,
config
.
engine_args
.
data_parallel_rank
or
0
,
metrics_labels
=
[(
"model"
,
config
.
model
)],
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)],
)
engine_client
,
vllm_config
,
default_sampling_params
=
setup_vllm_engine
(
config
,
factory
...
...
@@ -262,8 +276,6 @@ async def init(runtime: DistributedRuntime, config: Config):
factory
.
set_request_total_slots_all
(
vllm_config
.
scheduler_config
.
max_num_seqs
)
factory
.
init_publish
()
logger
.
info
(
f
"VllmWorker for
{
config
.
model
}
has been initialized"
)
handler
=
DecodeWorkerHandler
(
runtime
,
component
,
...
...
@@ -321,11 +333,12 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
config
.
migration_limit
<=
0
,
metrics_labels
=
[(
"model"
,
config
.
model
)],
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)],
health_check_payload
=
health_check_payload
,
),
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
metrics_labels
=
[(
"model"
,
config
.
model
)]
handler
.
clear_kv_blocks
,
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)],
),
)
logger
.
debug
(
"serve_endpoint completed for decode worker"
)
...
...
launch/dynamo-run/src/lib.rs
View file @
ab0da582
...
...
@@ -21,18 +21,32 @@ pub async fn run(
out_opt
:
Option
<
Output
>
,
mut
flags
:
Flags
,
)
->
anyhow
::
Result
<
()
>
{
//
// Download
//
let
maybe_remote_repo
=
flags
.model_path_pos
.clone
()
.or_else
(||
flags
.model_path_flag
.clone
());
let
model_path
=
match
maybe_remote_repo
{
None
=>
None
,
Some
(
p
)
if
p
.exists
()
=>
{
// Already a local path
Some
(
p
)
}
Some
(
p
)
=>
{
// model_path might be an HF repo, not a local path. Resolve it by downloading.
Some
(
LocalModel
::
fetch
(
&
p
.display
()
.to_string
(),
false
)
.await
?
)
}
};
//
// Configure
//
let
mut
builder
=
LocalModelBuilder
::
default
();
builder
.model_path
(
flags
.model_path_pos
.clone
()
.or
(
flags
.model_path_flag
.clone
()),
)
.model_name
(
flags
.model_name
.clone
())
.kv_cache_block_size
(
flags
.kv_cache_block_size
)
// Only set if user provides. Usually loaded from tokenizer_config.json
...
...
@@ -45,6 +59,11 @@ pub async fn run(
.migration_limit
(
flags
.migration_limit
)
.is_mocker
(
matches!
(
out_opt
,
Some
(
Output
::
Mocker
)));
// Only the worker has a model path
if
let
Some
(
model_path
)
=
model_path
{
builder
.model_path
(
model_path
);
}
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
...
...
lib/bindings/python/rust/lib.rs
View file @
ab0da582
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dynamo_llm
::
local_model
::
LocalModel
;
use
futures
::
StreamExt
;
use
once_cell
::
sync
::
OnceCell
;
use
pyo3
::
IntoPyObjectExt
;
...
...
@@ -9,6 +10,7 @@ use pyo3::types::{PyDict, PyString};
use
pyo3
::{
exceptions
::
PyException
,
prelude
::
*
};
use
rand
::
seq
::
IteratorRandom
as
_
;
use
rs
::
pipeline
::
network
::
Ingress
;
use
std
::
fs
;
use
std
::
net
::{
IpAddr
,
Ipv4Addr
,
SocketAddr
,
SocketAddrV4
};
use
std
::
path
::
PathBuf
;
use
std
::
time
::
Duration
;
...
...
@@ -96,6 +98,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_function
(
wrap_pyfunction!
(
llm
::
kv
::
compute_block_hash_for_seq_py
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
log_message
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
register_llm
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
fetch_llm
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
make_engine
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
run_input
,
m
)
?
)
?
;
...
...
@@ -174,6 +177,8 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
logging
::
log_message
(
level
,
message
,
module
,
file
,
line
);
}
/// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend.
#[pyfunction]
#[pyo3(signature
=
(model_input,
model_type,
endpoint,
model_path,
model_name=None,
context_length=None,
kv_cache_block_size=None,
router_mode=None,
migration_limit=
0
,
runtime_config=None,
user_data=None,
custom_template_path=None))]
#[allow(clippy::too_many_arguments)]
...
...
@@ -201,7 +206,7 @@ fn register_llm<'p>(
let
model_type_obj
=
model_type
.inner
;
let
inner_path
=
model_path
.to_string
();
let
model_name
=
model_name
.map
(|
n
|
n
.to_string
());
let
mut
model_name
=
model_name
.map
(|
n
|
n
.to_string
());
let
router_mode
=
router_mode
.unwrap_or
(
RouterMode
::
RoundRobin
);
let
router_config
=
RouterConfig
::
new
(
router_mode
.into
(),
KvRouterConfig
::
default
());
...
...
@@ -226,9 +231,22 @@ fn register_llm<'p>(
})
?
;
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
model_path
=
if
fs
::
exists
(
&
inner_path
)
?
{
PathBuf
::
from
(
inner_path
)
}
else
{
// Preserve the model name
if
model_name
.is_none
()
{
model_name
=
Some
(
inner_path
.clone
());
}
// Likely it's a Hugging Face repo, download it
LocalModel
::
fetch
(
&
inner_path
,
false
)
.await
.map_err
(
to_pyerr
)
?
};
let
mut
builder
=
dynamo_llm
::
local_model
::
LocalModelBuilder
::
default
();
builder
.model_path
(
Some
(
PathBuf
::
from
(
inner
_path
)
))
.model_path
(
model
_path
)
.model_name
(
model_name
)
.context_length
(
context_length
)
.kv_cache_block_size
(
kv_cache_block_size
)
...
...
@@ -237,7 +255,7 @@ fn register_llm<'p>(
.runtime_config
(
runtime_config
.unwrap_or_default
()
.inner
)
.user_data
(
user_data_json
)
.custom_template_path
(
custom_template_path_owned
);
//
Download from HF, l
oad the ModelDeploymentCard
//
L
oad the ModelDeploymentCard
let
mut
local_model
=
builder
.build
()
.await
.map_err
(
to_pyerr
)
?
;
// Advertise ourself on etcd so ingress can find us
local_model
...
...
@@ -249,6 +267,17 @@ fn register_llm<'p>(
})
}
/// Download a model from Hugging Face, returning it's local path
/// Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
#[pyfunction]
#[pyo3(signature
=
(remote_name))]
fn
fetch_llm
<
'p
>
(
py
:
Python
<
'p
>
,
remote_name
:
&
str
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
repo
=
remote_name
.to_string
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
LocalModel
::
fetch
(
&
repo
,
false
)
.await
.map_err
(
to_pyerr
)
})
}
#[pyclass]
#[derive(Clone)]
pub
struct
DistributedRuntime
{
...
...
lib/bindings/python/rust/llm/entrypoint.rs
View file @
ab0da582
...
...
@@ -180,6 +180,8 @@ pub(crate) struct EngineConfig {
inner
:
RsEngineConfig
,
}
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
#[pyfunction]
#[pyo3(signature
=
(distributed_runtime,
args))]
pub
fn
make_engine
<
'p
>
(
...
...
@@ -189,8 +191,11 @@ pub fn make_engine<'p>(
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
mut
builder
=
LocalModelBuilder
::
default
();
builder
.model_path
(
args
.model_path
.clone
())
.model_name
(
args
.model_name
.clone
())
.model_name
(
args
.model_name
.clone
()
.or_else
(||
args
.model_path
.clone
()
.map
(|
p
|
p
.display
()
.to_string
())),
)
.endpoint_id
(
args
.endpoint_id
.clone
())
.context_length
(
args
.context_length
)
.request_template
(
args
.template_file
.clone
())
...
...
@@ -206,6 +211,17 @@ pub fn make_engine<'p>(
.custom_backend_metrics_endpoint
(
args
.custom_backend_metrics_endpoint
.clone
())
.custom_backend_metrics_polling_interval
(
args
.custom_backend_metrics_polling_interval
);
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
if
let
Some
(
model_path
)
=
args
.model_path
.clone
()
{
let
local_path
=
if
model_path
.exists
()
{
model_path
}
else
{
LocalModel
::
fetch
(
&
model_path
.display
()
.to_string
(),
false
)
.await
.map_err
(
to_pyerr
)
?
};
builder
.model_path
(
local_path
);
}
let
local_model
=
builder
.build
()
.await
.map_err
(
to_pyerr
)
?
;
let
inner
=
select_engine
(
distributed_runtime
,
args
,
local_model
)
.await
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
ab0da582
...
...
@@ -892,6 +892,13 @@ async def register_llm(
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...
async def fetch_llm(remote_name: str) -> str:
"""
Download a model from Hugging Face, returning it's local path.
Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
"""
...
class EngineConfig:
"""Holds internal configuration for a Dynamo engine."""
...
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
ab0da582
...
...
@@ -42,6 +42,7 @@ from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from
dynamo._core
import
ZmqKvEventPublisher
as
ZmqKvEventPublisher
from
dynamo._core
import
ZmqKvEventPublisherConfig
as
ZmqKvEventPublisherConfig
from
dynamo._core
import
compute_block_hash_for_seq_py
as
compute_block_hash_for_seq_py
from
dynamo._core
import
fetch_llm
as
fetch_llm
from
dynamo._core
import
make_engine
from
dynamo._core
import
register_llm
as
register_llm
from
dynamo._core
import
run_input
...
...
lib/llm/src/local_model.rs
View file @
ab0da582
...
...
@@ -5,7 +5,6 @@ use std::fs;
use
std
::
path
::{
Path
,
PathBuf
};
use
std
::
sync
::
Arc
;
use
anyhow
::
Context
as
_
;
use
dynamo_runtime
::
protocols
::
EndpointId
;
use
dynamo_runtime
::
slug
::
Slug
;
use
dynamo_runtime
::
storage
::
key_value_store
::
Key
;
...
...
@@ -25,9 +24,6 @@ pub mod runtime_config;
use
runtime_config
::
ModelRuntimeConfig
;
/// Prefix for Hugging Face model repository
const
HF_SCHEME
:
&
str
=
"hf://"
;
/// What we call a model if the user didn't provide a name. Usually this means the name
/// is invisible, for example in a text chat.
const
DEFAULT_NAME
:
&
str
=
"dynamo"
;
...
...
@@ -90,8 +86,9 @@ impl Default for LocalModelBuilder {
}
impl
LocalModelBuilder
{
pub
fn
model_path
(
&
mut
self
,
model_path
:
Option
<
PathBuf
>
)
->
&
mut
Self
{
self
.model_path
=
model_path
;
/// The path must exist
pub
fn
model_path
(
&
mut
self
,
model_path
:
PathBuf
)
->
&
mut
Self
{
self
.model_path
=
Some
(
model_path
);
self
}
...
...
@@ -214,7 +211,7 @@ impl LocalModelBuilder {
.map
(
RequestTemplate
::
load
)
.transpose
()
?
;
// echo engine do
es
n't need a path.
It's an edge case, move it out of the way.
//
frontend and
echo engine don't need a path.
if
self
.model_path
.is_none
()
{
let
mut
card
=
ModelDeploymentCard
::
with_name_only
(
self
.model_name
.as_deref
()
.unwrap_or
(
DEFAULT_NAME
),
...
...
@@ -243,40 +240,24 @@ impl LocalModelBuilder {
// Main logic. We are running a model.
let
model_path
=
self
.model_path
.take
()
.unwrap
();
let
model_path
=
model_path
.to_str
()
.context
(
"Invalid UTF-8 in model path"
)
?
;
// Check for hf:// prefix first, in case we really want an HF repo but it conflicts
// with a relative path.
let
is_hf_repo
=
model_path
.starts_with
(
HF_SCHEME
)
||
!
fs
::
exists
(
model_path
)
.unwrap_or
(
false
);
let
relative_path
=
model_path
.trim_start_matches
(
HF_SCHEME
);
let
full_path
=
if
is_hf_repo
{
// HF download if necessary
super
::
hub
::
from_hf
(
relative_path
,
self
.is_mocker
)
.await
?
}
else
{
fs
::
canonicalize
(
relative_path
)
?
};
if
!
model_path
.exists
()
{
anyhow
::
bail!
(
"Path does not exist: '{}'. Use LocalModel::fetch to download it."
,
model_path
.display
(),
);
}
let
model_path
=
fs
::
canonicalize
(
model_path
)
?
;
let
mut
card
=
ModelDeploymentCard
::
load_from_disk
(
&
full_path
,
self
.custom_template_path
.as_deref
())
?
;
// Usually we infer from the path, self.model_name is user override
let
model_name
=
self
.model_name
.take
()
.unwrap_or_else
(||
{
if
is_hf_repo
{
// HF repos use their full name ("org/name") not the folder name
relative_path
.to_string
()
}
else
{
full_path
.iter
()
.next_back
()
.map
(|
n
|
n
.to_string_lossy
()
.into_owned
())
.unwrap_or_else
(||
{
// Panic because we can't do anything without a model
panic!
(
"Invalid model path, too short: '{}'"
,
full_path
.display
())
})
}
});
card
.set_name
(
&
model_name
);
ModelDeploymentCard
::
load_from_disk
(
&
model_path
,
self
.custom_template_path
.as_deref
())
?
;
// The served model name defaults to the full model path.
// This matches what vllm and sglang do.
card
.set_name
(
&
self
.model_name
.clone
()
.unwrap_or_else
(||
model_path
.display
()
.to_string
()),
);
card
.kv_cache_block_size
=
self
.kv_cache_block_size
;
...
...
@@ -303,7 +284,7 @@ impl LocalModelBuilder {
Ok
(
LocalModel
{
card
,
full_path
,
full_path
:
model_path
,
endpoint_id
,
template
,
http_host
:
self
.http_host
.take
(),
...
...
@@ -337,6 +318,15 @@ pub struct LocalModel {
}
impl
LocalModel
{
/// Ensure a model is accessible locally, returning it's path.
/// Downloads the model from Hugging Face if necessary.
/// If ignore_weights is true, model weight files will be skipped and only the model config
/// will be downloaded.
/// Returns the path to the model files
pub
async
fn
fetch
(
remote_name
:
&
str
,
ignore_weights
:
bool
)
->
anyhow
::
Result
<
PathBuf
>
{
super
::
hub
::
from_hf
(
remote_name
,
ignore_weights
)
.await
}
pub
fn
card
(
&
self
)
->
&
ModelDeploymentCard
{
&
self
.card
}
...
...
lib/llm/src/model_card.rs
View file @
ab0da582
...
...
@@ -13,7 +13,7 @@
//! - Prompt formatter settings (PromptFormatterArtifact)
use
std
::
fmt
;
use
std
::
path
::
{
Path
,
PathBuf
}
;
use
std
::
path
::
Path
;
use
std
::
sync
::{
Arc
,
OnceLock
};
use
crate
::
common
::
checked_file
::
CheckedFile
;
...
...
@@ -485,38 +485,26 @@ impl ModelDeploymentCard {
/// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid
fn
from_local_path
(
local_
root_dir
:
impl
AsRef
<
Path
>
,
local_
path
:
impl
AsRef
<
Path
>
,
custom_template_path
:
Option
<&
Path
>
,
)
->
anyhow
::
Result
<
Self
>
{
let
local_root_dir
=
local_root_dir
.as_ref
();
check_valid_local_repo_path
(
local_root_dir
)
?
;
let
repo_id
=
local_root_dir
.canonicalize
()
?
.to_str
()
.ok_or_else
(||
anyhow
::
anyhow!
(
"Path contains invalid Unicode"
))
?
.to_string
();
let
model_name
=
local_root_dir
.file_name
()
.and_then
(|
n
|
n
.to_str
())
.ok_or_else
(||
anyhow
::
anyhow!
(
"Invalid model directory name"
))
?
;
Self
::
from_repo
(
&
repo_id
,
model_name
,
custom_template_path
)
}
fn
from_repo
(
repo_id
:
&
str
,
model_name
:
&
str
,
check_valid_local_repo_path
(
&
local_path
)
?
;
Self
::
from_repo_checkout
(
&
local_path
,
custom_template_path
)
}
fn
from_repo_checkout
(
local_path
:
impl
AsRef
<
Path
>
,
custom_template_path
:
Option
<&
Path
>
,
)
->
anyhow
::
Result
<
Self
>
{
let
local_path
=
local_path
.as_ref
();
// This is usually the right choice
let
context_length
=
crate
::
file_json_field
(
&
PathBuf
::
from
(
repo_id
)
.join
(
"config.json"
),
"max_position_embeddings"
,
)
let
context_length
=
crate
::
file_json_field
(
&
local_path
.join
(
"config.json"
),
"max_position_embeddings"
)
// But sometimes this is
.or_else
(|
_
|
{
crate
::
file_json_field
(
&
PathBuf
::
from
(
repo_id
)
.join
(
"tokenizer_config.json"
),
&
local_path
.join
(
"tokenizer_config.json"
),
"model_max_length"
,
)
})
...
...
@@ -544,16 +532,17 @@ impl ModelDeploymentCard {
CheckedFile
::
from_disk
(
template_path
)
?
,
))
}
else
{
PromptFormatterArtifact
::
chat_template_from_
repo
(
repo_id
)
?
PromptFormatterArtifact
::
chat_template_from_
disk
(
local_path
)
?
};
let
display_name
=
local_path
.display
()
.to_string
();
Ok
(
Self
{
display_name
:
model_name
.to_string
(
),
slug
:
Slug
::
from_string
(
model
_name
)
,
model_info
:
Some
(
ModelInfoType
::
from_
repo
(
repo_id
)
?
),
tokenizer
:
Some
(
TokenizerKind
::
from_
repo
(
repo_id
)
?
),
gen_config
:
GenerationConfig
::
from_
repo
(
repo_id
)
.ok
(),
// optional
prompt_formatter
:
PromptFormatterArtifact
::
from_
repo
(
repo_id
)
?
,
slug
:
Slug
::
from_string
(
&
display_name
),
display
_name
,
model_info
:
Some
(
ModelInfoType
::
from_
disk
(
local_path
)
?
),
tokenizer
:
Some
(
TokenizerKind
::
from_
disk
(
local_path
)
?
),
gen_config
:
GenerationConfig
::
from_
disk
(
local_path
)
.ok
(),
// optional
prompt_formatter
:
PromptFormatterArtifact
::
from_
disk
(
local_path
)
?
,
chat_template_file
,
prompt_context
:
None
,
// TODO - auto-detect prompt context
context_length
,
...
...
@@ -778,33 +767,43 @@ impl ModelInfo for HFConfig {
}
impl
ModelInfoType
{
pub
fn
from_repo
(
repo_id
:
&
str
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
PathBuf
::
from
(
repo_id
)
.join
(
"config.json"
))
.with_context
(||
format!
(
"unable to extract config.json from repo {repo_id}"
))
?
;
pub
fn
from_disk
(
directory
:
&
Path
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
directory
.join
(
"config.json"
))
.with_context
(||
{
format!
(
"unable to extract config.json from directory {}"
,
directory
.display
()
)
})
?
;
Ok
(
Self
::
HfConfigJson
(
f
))
}
}
impl
GenerationConfig
{
pub
fn
from_repo
(
repo_id
:
&
str
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
PathBuf
::
from
(
repo_id
)
.join
(
"generation_config.json"
))
.with_context
(||
format!
(
"unable to extract generation_config from repo {repo_id}"
))
?
;
pub
fn
from_disk
(
directory
:
&
Path
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
directory
.join
(
"generation_config.json"
))
.with_context
(
||
{
format!
(
"unable to extract generation_config from directory {}"
,
directory
.display
()
)
},
)
?
;
Ok
(
Self
::
HfGenerationConfigJson
(
f
))
}
}
impl
PromptFormatterArtifact
{
pub
fn
from_
repo
(
repo_id
:
&
str
)
->
Result
<
Option
<
Self
>>
{
pub
fn
from_
disk
(
directory
:
&
Path
)
->
Result
<
Option
<
Self
>>
{
// we should only error if we expect a prompt formatter and it's not found
// right now, we don't know when to expect it, so we just return Ok(Some/None)
match
CheckedFile
::
from_disk
(
PathBuf
::
from
(
repo_id
)
.join
(
"tokenizer_config.json"
))
{
match
CheckedFile
::
from_disk
(
directory
.join
(
"tokenizer_config.json"
))
{
Ok
(
f
)
=>
Ok
(
Some
(
Self
::
HfTokenizerConfigJson
(
f
))),
Err
(
_
)
=>
Ok
(
None
),
}
}
pub
fn
chat_template_from_
repo
(
repo_id
:
&
str
)
->
Result
<
Option
<
Self
>>
{
match
CheckedFile
::
from_disk
(
PathBuf
::
from
(
repo_id
)
.join
(
"chat_template.jinja"
))
{
pub
fn
chat_template_from_
disk
(
directory
:
&
Path
)
->
Result
<
Option
<
Self
>>
{
match
CheckedFile
::
from_disk
(
directory
.join
(
"chat_template.jinja"
))
{
Ok
(
f
)
=>
Ok
(
Some
(
Self
::
HfChatTemplate
(
f
))),
Err
(
_
)
=>
Ok
(
None
),
}
...
...
@@ -812,9 +811,13 @@ impl PromptFormatterArtifact {
}
impl
TokenizerKind
{
pub
fn
from_repo
(
repo_id
:
&
str
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
PathBuf
::
from
(
repo_id
)
.join
(
"tokenizer.json"
))
.with_context
(||
format!
(
"unable to extract tokenizer kind from repo {repo_id}"
))
?
;
pub
fn
from_disk
(
directory
:
&
Path
)
->
Result
<
Self
>
{
let
f
=
CheckedFile
::
from_disk
(
directory
.join
(
"tokenizer.json"
))
.with_context
(||
{
format!
(
"unable to extract tokenizer kind from directory {}"
,
directory
.display
()
)
})
?
;
Ok
(
Self
::
HfTokenizerJson
(
f
))
}
}
...
...
tests/unit/test_sglang_unit.py
View file @
ab0da582
...
...
@@ -29,7 +29,8 @@ pytestmark = [
mock_sglang_cli
=
make_cli_args_fixture
(
"dynamo.sglang"
)
def
test_custom_jinja_template_invalid_path
(
mock_sglang_cli
):
@
pytest
.
mark
.
asyncio
async
def
test_custom_jinja_template_invalid_path
(
mock_sglang_cli
):
"""Test that invalid file path raises FileNotFoundError."""
invalid_path
=
"/nonexistent/path/to/template.jinja"
mock_sglang_cli
(
...
...
@@ -40,14 +41,15 @@ def test_custom_jinja_template_invalid_path(mock_sglang_cli):
FileNotFoundError
,
match
=
re
.
escape
(
f
"Custom Jinja template file not found:
{
invalid_path
}
"
),
):
parse_args
(
sys
.
argv
[
1
:])
await
parse_args
(
sys
.
argv
[
1
:])
def
test_custom_jinja_template_valid_path
(
mock_sglang_cli
):
@
pytest
.
mark
.
asyncio
async
def
test_custom_jinja_template_valid_path
(
mock_sglang_cli
):
"""Test that valid absolute path is stored correctly."""
mock_sglang_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
JINJA_TEMPLATE_PATH
)
config
=
parse_args
(
sys
.
argv
[
1
:])
config
=
await
parse_args
(
sys
.
argv
[
1
:])
assert
config
.
dynamo_args
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
f
"Expected custom_jinja_template value to be
{
JINJA_TEMPLATE_PATH
}
, "
...
...
@@ -55,7 +57,8 @@ def test_custom_jinja_template_valid_path(mock_sglang_cli):
)
def
test_custom_jinja_template_env_var_expansion
(
monkeypatch
,
mock_sglang_cli
):
@
pytest
.
mark
.
asyncio
async
def
test_custom_jinja_template_env_var_expansion
(
monkeypatch
,
mock_sglang_cli
):
"""Test that environment variables in paths are expanded by Python code."""
jinja_dir
=
str
(
TEST_DIR
/
"serve"
/
"fixtures"
)
monkeypatch
.
setenv
(
"JINJA_DIR"
,
jinja_dir
)
...
...
@@ -63,7 +66,7 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_sglang_cli):
cli_path
=
"$JINJA_DIR/custom_template.jinja"
mock_sglang_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
cli_path
)
config
=
parse_args
(
sys
.
argv
[
1
:])
config
=
await
parse_args
(
sys
.
argv
[
1
:])
assert
"$JINJA_DIR"
not
in
config
.
dynamo_args
.
custom_jinja_template
assert
config
.
dynamo_args
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment