Unverified commit 73e9a2ef authored by Mick, committed by GitHub

fix: tiny fix cli (#12744)

parent 837b08eb
 import argparse
 
-from sglang.cli.main import get_is_diffusion_model, get_model_path
+from sglang.cli.utils import get_is_diffusion_model, get_model_path
 from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
     add_multimodal_gen_generate_args,
     generate_cmd,
@@ -8,6 +8,13 @@ from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
 def generate(args, extra_argv):
+    # If help is requested, show generate subcommand help without requiring --model-path
+    if any(h in extra_argv for h in ("-h", "--help")):
+        parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
+        add_multimodal_gen_generate_args(parser)
+        parser.parse_args(extra_argv)
+        return
     model_path = get_model_path(extra_argv)
     is_diffusion_model = get_is_diffusion_model(model_path)
     if is_diffusion_model:
...
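A side note on the new help path: when `-h` or `--help` is in `extra_argv`, the `parser.parse_args(extra_argv)` call prints the help text and raises `SystemExit` itself, so the trailing `return` is only a fallback. A minimal stand-alone demonstration of that argparse behavior (the `prog` name and flag below are illustrative, not the real generate arguments):

```python
import argparse

parser = argparse.ArgumentParser(prog="sglang generate")
parser.add_argument("--model-path")  # stand-in for add_multimodal_gen_generate_args

try:
    parser.parse_args(["--help"])
except SystemExit:
    # argparse prints the full help text, then raises SystemExit(0)
    print("help was printed; the process would normally exit here")
```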
 import argparse
-import hashlib
-import json
-import logging
-import os
-import tempfile
-from typing import Optional
-
-import filelock
-from huggingface_hub import hf_hub_download
-
 from sglang.cli.generate import generate
 from sglang.cli.serve import serve
-
-logger = logging.getLogger(__name__)
-temp_dir = tempfile.gettempdir()
-
-
-def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
-    lock_dir = cache_dir or temp_dir
-    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
-    model_name = model_name_or_path.replace("/", "-")
-    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
-    lock_file_name = hash_name + model_name + ".lock"
-    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
-    return lock
-
-
-# Copied and adapted from hf_diffusers_utils.py
-def _maybe_download_model(
-    model_name_or_path: str, local_dir: str | None = None, download: bool = True
-) -> str:
-    """
-    Resolve a model path. If it's a local directory, return it.
-    If it's a Hugging Face Hub ID, download only the config file
-    (`model_index.json` or `config.json`) and return its directory.
-
-    Args:
-        model_name_or_path: Local path or Hugging Face Hub model ID
-        local_dir: Local directory to save the downloaded file (if any)
-        download: Whether to download from Hugging Face Hub when needed
-
-    Returns:
-        Local directory path that contains the downloaded config file, or the original local directory.
-    """
-    if os.path.exists(model_name_or_path):
-        logger.info("Model already exists locally")
-        return model_name_or_path
-    if not download:
-        return model_name_or_path
-
-    with _get_lock(model_name_or_path):
-        # Try `model_index.json` first (diffusers models)
-        try:
-            logger.info(
-                "Downloading model_index.json from HF Hub for %s...",
-                model_name_or_path,
-            )
-            file_path = hf_hub_download(
-                repo_id=model_name_or_path,
-                filename="model_index.json",
-                local_dir=local_dir,
-            )
-            logger.info("Downloaded to %s", file_path)
-            return os.path.dirname(file_path)
-        except Exception as e_index:
-            logger.debug("model_index.json not found or failed: %s", e_index)
-
-        # Fallback to `config.json`
-        try:
-            logger.info(
-                "Downloading config.json from HF Hub for %s...", model_name_or_path
-            )
-            file_path = hf_hub_download(
-                repo_id=model_name_or_path,
-                filename="config.json",
-                local_dir=local_dir,
-            )
-            logger.info("Downloaded to %s", file_path)
-            return os.path.dirname(file_path)
-        except Exception as e_config:
-            raise ValueError(
-                (
-                    "Could not find model locally at %s and failed to download "
-                    "model_index.json/config.json from HF Hub: %s"
-                )
-                % (model_name_or_path, e_config)
-            ) from e_config
-
-
-# Copied and adapted from hf_diffusers_utils.py
-def is_diffusers_model_path(model_path: str) -> True:
-    """
-    Verify if the model directory contains a valid diffusers configuration.
-
-    Args:
-        model_path: Path to the model directory
-
-    Returns:
-        The loaded model configuration as a dictionary if the model is a diffusers model
-        None if the model is not a diffusers model
-    """
-    # Prefer model_index.json which indicates a diffusers pipeline
-    config_path = os.path.join(model_path, "model_index.json")
-    if not os.path.exists(config_path):
-        return False
-
-    # Load the config
-    with open(config_path) as f:
-        config = json.load(f)
-
-    # Verify diffusers version exists
-    if "_diffusers_version" not in config:
-        return False
-
-    return True
-
-
-def get_is_diffusion_model(model_path: str):
-    model_path = _maybe_download_model(model_path)
-    is_diffusion_model = is_diffusers_model_path(model_path)
-    if is_diffusion_model:
-        logger.info("Diffusion model detected")
-    return is_diffusion_model
-
-
-def get_model_path(extra_argv):
-    # Find the model_path argument
-    model_path = None
-    for i, arg in enumerate(extra_argv):
-        if arg == "--model-path":
-            if i + 1 < len(extra_argv):
-                model_path = extra_argv[i + 1]
-            break
-        elif arg.startswith("--model-path="):
-            model_path = arg.split("=", 1)[1]
-            break
-
-    if model_path is None:
-        # Fallback for --help or other cases where model-path is not provided
-        if any(h in extra_argv for h in ["-h", "--help"]):
-            raise Exception(
-                "Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
-                "This command can launch either a standard language model server or a diffusion model server.\n"
-                "The server type is determined by the model path.\n"
-                "For specific arguments, please provide a model_path."
-            )
-        else:
-            raise Exception(
-                "Error: --model-path is required. "
-                "Please provide the path to the model."
-            )
-    return model_path
-
-
 def main():
     parser = argparse.ArgumentParser()
...
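Taken together, this hunk moves the path-resolution and model-detection helpers out of `sglang.cli.main` and into the new `sglang.cli.utils` module shown further below. Since `main` imports `generate` and `serve`, which in turn imported these helpers back from `main`, the move also appears to eliminate a circular import between the CLI modules.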
@@ -4,16 +4,49 @@ import argparse
 import logging
 import os
 
-from sglang.cli.main import get_is_diffusion_model, get_model_path
+from sglang.cli.utils import get_is_diffusion_model, get_model_path
 from sglang.srt.utils import kill_process_tree
 
 logger = logging.getLogger(__name__)
 
 
 def serve(args, extra_argv):
+    if any(h in extra_argv for h in ("-h", "--help")):
+        # Since the server type is determined by the model, and we don't have a model path,
+        # we can't show the exact help. Instead, we show a general help message and then
+        # the help for both possible server types.
+        print(
+            "Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n"
+        )
+        print(
+            "This command can launch either a standard language model server or a diffusion model server."
+        )
+        print("The server type is determined by the model path.\n")
+        print("For specific arguments, please provide a model_path.")
+
+        print("\n--- Help for Standard Language Model Server ---")
+        from sglang.srt.server_args import prepare_server_args
+
+        try:
+            prepare_server_args(["--help"])
+        except SystemExit:
+            pass  # argparse --help calls sys.exit
+
+        print("\n--- Help for Diffusion Model Server ---")
+        from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
+            add_multimodal_gen_serve_args,
+        )
+
+        parser = argparse.ArgumentParser(description="SGLang Diffusion Model Serving")
+        add_multimodal_gen_serve_args(parser)
+        parser.print_help()
+        return
+
     model_path = get_model_path(extra_argv)
     try:
         is_diffusion_model = get_is_diffusion_model(model_path)
-        if is_diffusion_model:
-            logger.info("Diffusion model detected")
         if is_diffusion_model:
             # Logic for Diffusion Models
             from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
...
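Note the `try/except SystemExit` guard around `prepare_server_args(["--help"])`: as the sketch after the first file's hunk demonstrates, argparse exits the process after printing help, so swallowing `SystemExit` is what lets `serve` print the standard server's help and then continue on to the diffusion server's help.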
+import hashlib
+import json
+import logging
+import os
+import tempfile
+from typing import Optional
+
+import filelock
+from huggingface_hub import hf_hub_download
+
+logger = logging.getLogger(__name__)
+temp_dir = tempfile.gettempdir()
+
+
+def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
+    lock_dir = cache_dir or temp_dir
+    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
+    model_name = model_name_or_path.replace("/", "-")
+    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
+    lock_file_name = hash_name + model_name + ".lock"
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
+    return lock
+
+
+# Copied and adapted from hf_diffusers_utils.py
+def _maybe_download_model(
+    model_name_or_path: str, local_dir: str | None = None, download: bool = True
+) -> str:
+    """
+    Resolve a model path. If it's a local directory, return it.
+    If it's a Hugging Face Hub ID, download only the config file
+    (`model_index.json` or `config.json`) and return its directory.
+
+    Args:
+        model_name_or_path: Local path or Hugging Face Hub model ID
+        local_dir: Local directory to save the downloaded file (if any)
+        download: Whether to download from Hugging Face Hub when needed
+
+    Returns:
+        Local directory path that contains the downloaded config file, or the original local directory.
+    """
+    if os.path.exists(model_name_or_path):
+        logger.info("Model already exists locally")
+        return model_name_or_path
+    if not download:
+        return model_name_or_path
+
+    with _get_lock(model_name_or_path):
+        # Try `model_index.json` first (diffusers models)
+        try:
+            logger.info(
+                "Downloading model_index.json from HF Hub for %s...",
+                model_name_or_path,
+            )
+            file_path = hf_hub_download(
+                repo_id=model_name_or_path,
+                filename="model_index.json",
+                local_dir=local_dir,
+            )
+            logger.info("Downloaded to %s", file_path)
+            return os.path.dirname(file_path)
+        except Exception as e_index:
+            logger.debug("model_index.json not found or failed: %s", e_index)
+
+        # Fallback to `config.json`
+        try:
+            logger.info(
+                "Downloading config.json from HF Hub for %s...", model_name_or_path
+            )
+            file_path = hf_hub_download(
+                repo_id=model_name_or_path,
+                filename="config.json",
+                local_dir=local_dir,
+            )
+            logger.info("Downloaded to %s", file_path)
+            return os.path.dirname(file_path)
+        except Exception as e_config:
+            raise ValueError(
+                (
+                    "Could not find model locally at %s and failed to download "
+                    "model_index.json/config.json from HF Hub: %s"
+                )
+                % (model_name_or_path, e_config)
+            ) from e_config
+
+
+# Copied and adapted from hf_diffusers_utils.py
+def is_diffusers_model_path(model_path: str) -> bool:
+    """
+    Verify if the model directory contains a valid diffusers configuration.
+
+    Args:
+        model_path: Path to the model directory
+
+    Returns:
+        True if the directory contains a `model_index.json` with a
+        `_diffusers_version` key, False otherwise.
+    """
+    # Prefer model_index.json, which indicates a diffusers pipeline
+    config_path = os.path.join(model_path, "model_index.json")
+    if not os.path.exists(config_path):
+        return False
+
+    # Load the config
+    with open(config_path) as f:
+        config = json.load(f)
+
+    # Verify the diffusers version key exists
+    if "_diffusers_version" not in config:
+        return False
+
+    return True
+
+
+def get_is_diffusion_model(model_path: str):
+    model_path = _maybe_download_model(model_path)
+    is_diffusion_model = is_diffusers_model_path(model_path)
+    if is_diffusion_model:
+        logger.info("Diffusion model detected")
+    return is_diffusion_model
+
+
+def get_model_path(extra_argv):
+    # Find the --model-path argument
+    model_path = None
+    for i, arg in enumerate(extra_argv):
+        if arg == "--model-path":
+            if i + 1 < len(extra_argv):
+                model_path = extra_argv[i + 1]
+            break
+        elif arg.startswith("--model-path="):
+            model_path = arg.split("=", 1)[1]
+            break
+
+    if model_path is None:
+        # Fallback for --help or other cases where --model-path is not provided
+        if any(h in extra_argv for h in ["-h", "--help"]):
+            raise Exception(
+                "Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
+                "This command can launch either a standard language model server or a diffusion model server.\n"
+                "The server type is determined by the model path.\n"
+                "For specific arguments, please provide a model_path."
+            )
+        else:
+            raise Exception(
+                "Error: --model-path is required. "
+                "Please provide the path to the model."
+            )
+    return model_path
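For reference, a quick usage sketch of the moved `get_model_path` helper. This assumes an environment where `sglang` is installed with the module layout from this commit; the model ID is borrowed from the docs example below.

```python
from sglang.cli.utils import get_model_path

# Both flag spellings are recognized.
argv_a = ["--model-path", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "--ring-degree=2"]
argv_b = ["--model-path=Wan-AI/Wan2.1-T2V-1.3B-Diffusers"]
assert get_model_path(argv_a) == "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
assert get_model_path(argv_b) == "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Without --model-path (and without -h/--help), the helper raises
# instead of returning None.
try:
    get_model_path([])
except Exception as e:
    print(e)  # Error: --model-path is required. ...
```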
@@ -143,7 +143,7 @@ SERVER_ARGS=(
   --ring-degree=2
 )
 
-sglang serve $SERVER_ARGS
+sglang serve "${SERVER_ARGS[@]}"
 ```
 
 - **--model-path**: Which model to load. The example uses `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`.
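The fix here is bash array expansion: unquoted `$SERVER_ARGS` expands to only the first element of the `SERVER_ARGS` array, while `"${SERVER_ARGS[@]}"` expands every element as its own properly quoted word. The `generate` example below receives the same correction.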
@@ -265,7 +265,7 @@ SAMPLING_ARGS=(
   --output-file-name "A curious raccoon.mp4"
 )
 
-sglang generate $SERVER_ARGS $SAMPLING_ARGS
+sglang generate "${SERVER_ARGS[@]}" "${SAMPLING_ARGS[@]}"
 ```
 
 Once the generation task has finished, the server will shut down automatically.
...
@@ -258,10 +258,10 @@ class DiffGenerator:
         data_type = (
             DataType.IMAGE
             if self.server_args.pipeline_config.is_image_gen
-            or sampling_params.num_frames == 1
+            or pretrained_sampling_params.num_frames == 1
            else DataType.VIDEO
         )
-        sampling_params.data_type = data_type
+        pretrained_sampling_params.data_type = data_type
         pretrained_sampling_params.set_output_file_name()
 
         requests: list[Req] = []
...
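The rename fixes a stale-variable bug: `num_frames` is now read from, and the resolved `data_type` written back to, the same `pretrained_sampling_params` object whose `set_output_file_name()` is invoked on the following line, rather than the other `sampling_params` object.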
@@ -217,6 +217,7 @@ class CudaPlatformBase(Platform):
         elif selected_backend == AttentionBackendEnum.FA3:
             if is_blackwell():
                 raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+            target_backend = AttentionBackendEnum.FA3
         elif selected_backend:
             raise ValueError(f"Invalid attention backend for {cls.device_name}")
         else:
...
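Judging from the surrounding branches, the FA3 arm previously validated the hardware but never recorded the selection, so an explicit `fa3` request appears to have fallen through to the default path; the added assignment makes the choice stick. A minimal stand-alone sketch of the fixed control flow (`Backend`, `resolve_backend`, and `on_blackwell` are illustrative stand-ins, not the real sglang names):

```python
from enum import Enum, auto


class Backend(Enum):  # stand-in for AttentionBackendEnum
    FA3 = auto()
    DEFAULT = auto()


def resolve_backend(selected: "Backend | None", on_blackwell: bool) -> Backend:
    if selected == Backend.FA3:
        if on_blackwell:
            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
        return Backend.FA3  # corresponds to the assignment this commit adds
    elif selected:
        raise ValueError("Invalid attention backend")
    # the real code runs its default backend selection here
    return Backend.DEFAULT


print(resolve_backend(Backend.FA3, on_blackwell=False))  # Backend.FA3
```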
@@ -777,6 +777,13 @@ class ServerArgs:
             )
             self.sp_degree = self.ulysses_degree = self.ring_degree = 1
 
+        if (
+            self.ring_degree is not None
+            and self.ring_degree > 1
+            and self.attention_backend != "fa3"
+        ):
+            raise ValueError("Ring Attention is only supported for fa3 backend for now")
+
         if self.sp_degree == -1:
             # assume we leave all remaining gpus to sp
             num_gpus_per_group = self.dp_size * self.tp_size
...
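This validation pairs with the docs example above: a run with `--ring-degree=2` must also select the `fa3` attention backend, otherwise `ServerArgs` now raises at startup.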