Commit 1630f8ba, authored by fsaady, committed by GitHub

feat: automate slurm handling in sglang example. (#1730)


Signed-off-by: Fadi Saady <fsaady@nvidia.com>
parent dda59e31
# Example: Deploy Multi-node SGLang with Dynamo on SLURM
This folder contains the scripts for running the [SGLang DeepSeek-R1 Disaggregated with WideEP](../dsr1-wideep.md) example on a SLURM cluster.
## Overview
The scripts in this folder set up multiple cluster nodes to run the example, with separate nodes handling prefill and decode.
A Python submission script renders a SLURM job script from a Jinja2 template, so node counts, paths, and the container image can be set from the command line. Each node also logs its GPU utilization so that it can be tracked during benchmarks.
## Scripts
- **`submit_job_script.py`**: Main script for generating and submitting SLURM job scripts from templates
- **`job_script_template.j2`**: Jinja2 template for generating SLURM job scripts
- **`scripts/worker_setup.py`**: Worker script that handles the setup on each node
- **`scripts/monitor_gpu_utilization.sh`**: Script for monitoring GPU utilization during benchmarks
## Logs Folder Structure
Each SLURM job creates a unique log directory under `logs/` using the job ID. For example, job ID `3062824` creates the directory `logs/3062824/`.
### Log File Structure
```
logs/
├── 3062824/ # Job ID directory
│ ├── log.out # Main job output (node allocation, IP addresses, launch commands)
│ ├── log.err # Main job errors
│ ├── node0197_prefill.out # Prefill node stdout (node0197)
│ ├── node0197_prefill.err # Prefill node stderr (node0197)
│ ├── node0200_prefill.out # Prefill node stdout (node0200)
│ ├── node0200_prefill.err # Prefill node stderr (node0200)
│ ├── node0201_decode.out # Decode node stdout (node0201)
│ ├── node0201_decode.err # Decode node stderr (node0201)
│ ├── node0204_decode.out # Decode node stdout (node0204)
│ ├── node0204_decode.err # Decode node stderr (node0204)
│ ├── node0197_prefill_gpu_utilization.log # GPU utilization monitoring (node0197)
│ ├── node0200_prefill_gpu_utilization.log # GPU utilization monitoring (node0200)
│ ├── node0201_decode_gpu_utilization.log # GPU utilization monitoring (node0201)
│ └── node0204_decode_gpu_utilization.log # GPU utilization monitoring (node0204)
├── 3063137/ # Another job ID directory
├── 3062689/ # Another job ID directory
└── ...
```
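When inspecting a finished job, it can help to group the per-node files by node and role. The following is a minimal sketch that assumes the file naming shown in the tree above and node names without underscores; `group_node_logs` is a hypothetical helper, not part of this example's scripts.

```python
import re
from collections import defaultdict
from pathlib import Path

# Matches per-node files such as "node0197_prefill.out" or
# "node0201_decode_gpu_utilization.log" (naming taken from the tree above).
# Assumption: node names contain no underscore.
NODE_LOG_PATTERN = re.compile(
    r"^(?P<node>[^_]+)_(?P<role>prefill|decode)"
    r"(?:\.out|\.err|_gpu_utilization\.log)$"
)


def group_node_logs(job_log_dir: Path) -> dict[tuple[str, str], list[str]]:
    """Map (node, role) to the sorted list of log file names found for it."""
    grouped: dict[tuple[str, str], list[str]] = defaultdict(list)
    for path in sorted(job_log_dir.iterdir()):
        match = NODE_LOG_PATTERN.match(path.name)
        if match:  # job-level files (log.out, log.err) do not match
            grouped[(match["node"], match["role"])].append(path.name)
    return dict(grouped)
```

For example, `group_node_logs(Path("logs/3062824"))` would return one entry per `(node, role)` pair present in that job's directory.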
## Setup
For simplicity, this example makes a few assumptions about your SLURM cluster:
1. You have access to a SLURM cluster with multiple GPU nodes
available. For functional testing, most setups should be fine. For performance
testing, aim to allocate groups of nodes that share a high-bandwidth
interconnect, such as those in an NVL72 setup.
2. The cluster has the [Pyxis](https://github.com/NVIDIA/pyxis)
SPANK plugin set up. In particular, the `job_script_template.j2` template in this
example passes `srun` arguments such as `--container-image` and
`--container-mounts`, which Pyxis adds to `srun`.
If your cluster supports a similar container-based plugin, you may be able to
modify the template to use that instead.
3. You have already built a recent Dynamo+SGLang container image as
described [here](../dsr1-wideep.md#instructions).
This is the image to pass to the `--container-image` argument in later steps.
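To make the Pyxis dependency concrete, the sketch below assembles the container-related `srun` arguments that `job_script_template.j2` passes. The flag names are the ones used in the template; `build_container_args` itself is a hypothetical helper written for illustration, and the paths in the usage note are placeholders.

```python
def build_container_args(image: str, mounts: dict[str, str]) -> list[str]:
    """Return the Pyxis srun arguments used by the job template.

    `mounts` maps a host path to the path it should appear at inside
    the container, for example {"/path/to/model": "/model/"}.
    """
    # Pyxis expects a comma-separated list of SRC:DST pairs.
    mount_spec = ",".join(f"{src}:{dst}" for src, dst in mounts.items())
    return [
        f"--container-image={image}",
        "--no-container-entrypoint",
        "--container-mount-home",
        "--no-container-remap-root",
        f"--container-mounts={mount_spec}",
    ]
```

Calling `build_container_args("registry/repository:tag", {"/path/to/model": "/model/"})` yields a list that could be placed before the command in an `srun` invocation on a cluster where Pyxis is installed.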
## Usage
1. **Submit a benchmark job**:
```bash
python submit_job_script.py \
    --template job_script_template.j2 \
    --model-dir /path/to/model \
    --config-dir /path/to/configs \
    --container-image container-image-uri \
    --account your-slurm-account
```
**Required arguments**:
- `--template`: Path to Jinja2 template file
- `--model-dir`: Model directory path
- `--config-dir`: Config directory path
- `--container-image`: Container image URI (e.g., `registry/repository:tag`)
- `--account`: SLURM account
**Optional arguments**:
- `--prefill-nodes`: Number of prefill nodes (default: `2`)
- `--decode-nodes`: Number of decode nodes (default: `2`)
- `--gpus-per-node`: Number of GPUs per node (default: `8`)
- `--network-interface`: Network interface to use (default: `eth3`)
- `--job-name`: SLURM job name (default: `dynamo_setup`)
- `--time-limit`: Time limit in HH:MM:SS format (default: `01:00:00`)
**Note**: The script computes the total number of nodes to request as the sum of `--prefill-nodes` and `--decode-nodes`.
2. **Monitor job progress**:
```bash
squeue -u $USER
```
3. **Check logs in real-time**:
```bash
tail -f logs/{JOB_ID}/log.out
```
4. **Monitor GPU utilization**:
```bash
tail -f logs/{JOB_ID}/{node}_prefill_gpu_utilization.log
```
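The way node counts turn into per-node roles can be sketched as follows. This mirrors the logic of the job template and `worker_setup.py` as read from this commit: the first `--prefill-nodes` hosts become prefill workers, the rest become decode workers, ranks restart at 0 within each group, and the value passed as `--tp-size` and `--dp-size` is the group's node count times `--gpus-per-node`. The function `assign_roles` is a hypothetical illustration, not code from the example.

```python
def assign_roles(
    nodes: list[str],
    prefill_nodes: int,
    decode_nodes: int,
    gpus_per_node: int = 8,
) -> list[dict]:
    """Describe the worker type, rank, and GPU count for each allocated node."""
    if len(nodes) != prefill_nodes + decode_nodes:
        raise ValueError(
            f"Expected {prefill_nodes + decode_nodes} nodes but got {len(nodes)}"
        )
    plan = []
    for i, node in enumerate(nodes):
        if i < prefill_nodes:
            worker_type, rank, group_size = "prefill", i, prefill_nodes
        else:
            # Decode ranks restart at 0 for the first decode node.
            worker_type, rank, group_size = "decode", i - prefill_nodes, decode_nodes
        plan.append(
            {
                "node": node,
                "worker_type": worker_type,
                "rank": rank,
                "total_gpus": group_size * gpus_per_node,
            }
        )
    return plan
```

With the default values (2 prefill nodes, 2 decode nodes, 8 GPUs per node), each group spans 16 GPUs, and the rank-0 prefill node is the one that hosts NATS, etcd, and the HTTP ingress.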
## Outputs
Benchmark results and outputs are stored in the `outputs/` directory under the job submission directory, which is mounted into the container at `/outputs/`.
#!/bin/bash
#SBATCH --job-name={{ job_name }}
#SBATCH --nodes={{ total_nodes }}
#SBATCH --ntasks={{ total_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --account={{ account }}
#SBATCH --time={{ time_limit }}
#SBATCH --output=logs/%j/log.out
#SBATCH --error=logs/%j/log.err
# Constants
PREFILL_NODES={{ prefill_nodes }}
DECODE_NODES={{ decode_nodes }}
TOTAL_NODES=$((PREFILL_NODES + DECODE_NODES))
GPUS_PER_NODE={{ gpus_per_node }}
LOG_DIR="${SLURM_SUBMIT_DIR}/logs/${SLURM_JOB_ID}"
SCRIPT_DIR="${SLURM_SUBMIT_DIR}/scripts"
OUTPUT_DIR="${SLURM_SUBMIT_DIR}/outputs"
MODEL_DIR="{{ model_dir }}"
CONFIG_DIR="{{ config_dir }}"
CONTAINER_IMAGE="{{ container_image }}"
NETWORK_INTERFACE="{{ network_interface }}"
{% raw %}
mkdir -p "${OUTPUT_DIR}" "${LOG_DIR}"
nodes=($(scontrol show hostnames $SLURM_NODELIST))
if [ ${#nodes[@]} -ne $TOTAL_NODES ]; then
    echo "Error: Expected $TOTAL_NODES nodes but got ${#nodes[@]} nodes"
    exit 1
fi
# Print node information
for i in "${!nodes[@]}"; do
    echo "Node $i: ${nodes[$i]}"
done
PREFILL_HOST_IP=$(srun --nodes=1 --ntasks=1 --nodelist=${nodes[0]} ifconfig $NETWORK_INTERFACE | grep -oP 'inet \K[0-9.]+')
if [ -z "$PREFILL_HOST_IP" ]; then
    echo "Error: Could not retrieve IP address for prefill host ${nodes[0]} on interface $NETWORK_INTERFACE"
    exit 1
fi
echo "Prefill host IP address: $PREFILL_HOST_IP"
DECODE_HOST_IP=$(srun --nodes=1 --ntasks=1 --nodelist=${nodes[$PREFILL_NODES]} ifconfig $NETWORK_INTERFACE | grep -oP 'inet \K[0-9.]+')
if [ -z "$DECODE_HOST_IP" ]; then
    echo "Error: Could not retrieve IP address for decode host ${nodes[$PREFILL_NODES]} on interface $NETWORK_INTERFACE"
    exit 1
fi
echo "Decode host IP address: $DECODE_HOST_IP"
# Prepare enroot arguments to pass to srun commands
ENROOT_ARGS="\
--container-image=${CONTAINER_IMAGE} \
--no-container-entrypoint \
--container-mount-home \
--no-container-remap-root \
--container-mounts=${MODEL_DIR}:/model/,${CONFIG_DIR}:/configs/,${SCRIPT_DIR}:/scripts/,${OUTPUT_DIR}:/outputs/,${LOG_DIR}:/logs/ \
"
# Launch prefill tasks on the first PREFILL_NODES nodes
for i in $(seq 0 $((PREFILL_NODES - 1))); do
    node=${nodes[$i]}
    rank=$i
    echo "Launching prefill task on node ${i} (rank ${rank}): $node"
    echo "Srun args: $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_prefill.out --error=${LOG_DIR}/${node}_prefill.err"
    echo "Command: python /scripts/worker_setup.py --prefill_host_ip ${PREFILL_HOST_IP} --decode_host_ip ${DECODE_HOST_IP} --rank ${rank} --total_nodes ${PREFILL_NODES} --worker_type prefill --gpus_per_node ${GPUS_PER_NODE} --gpu_utilization_log /logs/${node}_prefill_gpu_utilization.log &"
    srun $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node \
        --output=${LOG_DIR}/${node}_prefill.out --error=${LOG_DIR}/${node}_prefill.err \
        python /scripts/worker_setup.py --prefill_host_ip ${PREFILL_HOST_IP} --decode_host_ip ${DECODE_HOST_IP} --rank ${rank} --total_nodes ${PREFILL_NODES} --worker_type prefill --gpus_per_node ${GPUS_PER_NODE} --gpu_utilization_log /logs/${node}_prefill_gpu_utilization.log &
done
# Launch decode tasks on the next DECODE_NODES nodes
for i in $(seq $PREFILL_NODES $((PREFILL_NODES + DECODE_NODES - 1))); do
    node=${nodes[$i]}
    rank=$((i - PREFILL_NODES))
    echo "Launching decode task on node ${i} (rank ${rank}): $node"
    echo "Srun args: $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_decode.out --error=${LOG_DIR}/${node}_decode.err"
    echo "Command: python /scripts/worker_setup.py --decode_host_ip ${DECODE_HOST_IP} --prefill_host_ip ${PREFILL_HOST_IP} --rank ${rank} --total_nodes ${DECODE_NODES} --worker_type decode --gpus_per_node ${GPUS_PER_NODE} --gpu_utilization_log /logs/${node}_decode_gpu_utilization.log &"
    srun $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node \
        --output=${LOG_DIR}/${node}_decode.out --error=${LOG_DIR}/${node}_decode.err \
        python /scripts/worker_setup.py --decode_host_ip ${DECODE_HOST_IP} --prefill_host_ip ${PREFILL_HOST_IP} --rank ${rank} --total_nodes ${DECODE_NODES} --worker_type decode --gpus_per_node ${GPUS_PER_NODE} --gpu_utilization_log /logs/${node}_decode_gpu_utilization.log &
done
echo ""
echo "To connect to the host prefill node:"
echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bash"
echo ""
echo "Make sure to cancel the job at the end:"
echo "scancel $SLURM_JOB_ID"
# Wait for all tasks to complete
wait
echo "Script finished at $(date)"
{% endraw %}
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Usage: ./monitor_gpu_utilization.sh [interval_seconds]
# Default interval is 2 seconds
INTERVAL=${1:-2}
# Check if nvidia-smi is available
if ! command -v nvidia-smi &> /dev/null; then
    echo "$(date '+%Y-%m-%d %H:%M:%S') Error: nvidia-smi not found"
    exit 1
fi
echo "Starting GPU utilization monitoring (checking every ${INTERVAL}s, printing only on changes)..."
# Make a pipeline fail when any stage fails, so that $? below reflects an
# nvidia-smi error instead of only the exit status of paste.
set -o pipefail
PREV_UTILIZATION=""
while true; do
    CURRENT_UTILIZATION=$(nvidia-smi --query-gpu=utilization.gpu --format=csv,nounits | paste -sd ' ' -)
    if [ $? -ne 0 ]; then
        echo "$(date '+%Y-%m-%d %H:%M:%S') Error: nvidia-smi command failed"
    else
        if [ "$CURRENT_UTILIZATION" != "$PREV_UTILIZATION" ]; then
            echo "$(date '+%Y-%m-%d %H:%M:%S') GPU Utilization: $CURRENT_UTILIZATION"
            PREV_UTILIZATION="$CURRENT_UTILIZATION"
        fi
    fi
    sleep "$INTERVAL"
done
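The log written by this monitor can be post-processed with a few lines of Python. The sketch below assumes each sample line has the form `YYYY-MM-DD HH:MM:SS GPU Utilization: utilization.gpu [%] x y z`, which is what the `echo` in the loop produces when `nvidia-smi` prints its CSV header followed by one value per GPU. `parse_utilization_line` is a hypothetical helper, not part of the commit.

```python
from datetime import datetime

MARKER = " GPU Utilization: "
CSV_HEADER = "utilization.gpu [%]"  # header printed by nvidia-smi with nounits


def parse_utilization_line(line: str):
    """Return (timestamp, [percent per GPU]), or None for any other line."""
    stamp, sep, rest = line.strip().partition(MARKER)
    if not sep or not rest.startswith(CSV_HEADER):
        return None  # startup banner, error message, or unrelated text
    try:
        timestamp = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S")
        values = [int(v) for v in rest[len(CSV_HEADER):].split()]
    except ValueError:
        return None  # malformed timestamp or non-integer reading
    return timestamp, values
```

Because the monitor prints only when the readings change, consecutive parsed samples mark the times at which utilization shifted rather than a fixed sampling grid.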
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
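Worker setup script for Slurm nodes.
This script runs on the prefill and decode nodes. It is launched with srun by the
SLURM job script rendered from job_script_template.j2.
The script will:
- Set up the environment (NATS_SERVER and ETCD_ENDPOINTS)
- Start nats-server, etcd, and the Dynamo HTTP ingress on the rank-0 prefill node
- Wait for etcd to become healthy on all other nodes
- Start the SGLang prefill or decode worker
- Monitor the GPU utilization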
"""
import argparse
import logging
import os
import socket
import subprocess
import time
from pathlib import Path
import requests
# Network configurations
ETCD_CLIENT_PORT = 2379
ETCD_PEER_PORT = 2380
NATS_PORT = 4222
DIST_INIT_PORT = 29500
ETCD_LISTEN_ADDR = "http://0.0.0.0"
def setup_logging(level: int = logging.INFO) -> None:
    logging.basicConfig(
        level=level,
        format="%(asctime)s| %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
def log_gpu_utilization(log_file: Path) -> None:
    """
    Log GPU utilization for all GPUs in the node.
    Format: utilization.gpu [%] x y z
    """
    util_script = Path(__file__).parent / "monitor_gpu_utilization.sh"
    util_process = run_command(
        f"bash {util_script}",
        background=True,
        stdout=open(log_file, "w"),
        stderr=subprocess.STDOUT,
    )
    if not util_process:
        logging.warning("Failed to start GPU utilization monitoring")
    else:
        logging.info("Started GPU utilization monitoring in the background")
def check_etcd_health(etcd_url: str) -> bool:
    """Check if etcd is healthy"""
    health_url = f"{etcd_url}/health"
    try:
        response = requests.get(health_url, timeout=5)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False
def wait_for_etcd(etcd_url: str, max_retries: int = 1000) -> bool:
    """Wait for etcd to be ready"""
    logging.info(f"Waiting for etcd to be ready on {etcd_url}...")
    for attempt in range(max_retries):
        try:
            if check_etcd_health(etcd_url):
                logging.info("Etcd is ready!")
                return True
        except requests.exceptions.RequestException:
            pass
        logging.info(
            f"Etcd not ready yet, retrying in 2 seconds... (attempt {attempt + 1}/{max_retries})"
        )
        time.sleep(2)
    logging.error("Etcd failed to become ready within the timeout period")
    return False
def run_command(
    cmd: str, background: bool = False, shell: bool = True, stdout=None, stderr=None
):
    """
    Run a command either in background or foreground.
    Args:
        cmd: Command to run
        background: If True, run in background and return Popen object. If False, wait for
            completion and return exit code.
        shell: Whether to run command through shell
        stdout: Destination for stdout of a background command (defaults to a pipe)
        stderr: Destination for stderr of a background command (defaults to a pipe)
    Returns:
        If background=True: subprocess.Popen
        If background=False: int (exit code)
    """
    logging.info(f"Running command (background={background}, shell={shell}): {cmd}")
    if background:
        process = subprocess.Popen(
            cmd,
            shell=shell,
            stdout=stdout if stdout else subprocess.PIPE,
            stderr=stderr if stderr else subprocess.PIPE,
        )  # noqa: S603
        return process
    else:
        result = subprocess.run(cmd, shell=shell, check=True)  # noqa: S603
        return result.returncode
def _parse_command_line_args(args: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Worker setup script for Dynamo distributed inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--prefill_host_ip",
        type=str,
        required=True,
        help="IP address of the prefill host node",
    )
    parser.add_argument(
        "--decode_host_ip",
        type=str,
        required=True,
        help="IP address of the decode host node",
    )
    parser.add_argument(
        "--rank",
        type=int,
        required=True,
        help="Rank of the current node (0 for host node)",
    )
    parser.add_argument(
        "--total_nodes",
        type=int,
        required=True,
        help="Total number of nodes of this worker type (prefill or decode)",
    )
    parser.add_argument(
        "--worker_type",
        choices=["decode", "prefill"],
        required=True,
        help="Type of worker to run",
    )
    parser.add_argument(
        "--gpus_per_node",
        type=int,
        default=8,
        help="Number of GPUs per node (default: 8)",
    )
    parser.add_argument(
        "--gpu_utilization_log",
        type=str,
        default=None,
        help="File to log GPU utilization (default: None)",
    )
    return parser.parse_args(args)
def _validate_args(args: argparse.Namespace) -> None:
    """Validate command line arguments"""
    if args.rank < 0:
        raise ValueError("Rank must be non-negative")
    if args.total_nodes < 1:
        raise ValueError("Total nodes must be at least 1")
    if args.gpus_per_node < 1:
        raise ValueError("GPUs per node must be at least 1")
def setup_prefill_node(
    rank: int, prefill_host_ip: str, total_nodes: int, total_gpus: int
) -> int:
    """
    Setup the prefill node.
    """
    if rank == 0:
        logging.info(f"Setting up host prefill node: {rank}")
        logging.info(f"Starting nats server on node {rank} with IP {prefill_host_ip}")
        nats_process = run_command("nats-server -js", background=True)
        if not nats_process:
            raise RuntimeError("Failed to start nats-server")
        etcd_cmd = (
            f"etcd --listen-client-urls {ETCD_LISTEN_ADDR}:{ETCD_CLIENT_PORT} "
            f"--advertise-client-urls {ETCD_LISTEN_ADDR}:{ETCD_CLIENT_PORT} "
            f"--listen-peer-urls {ETCD_LISTEN_ADDR}:{ETCD_PEER_PORT} "
            f"--initial-cluster default=http://{prefill_host_ip}:{ETCD_PEER_PORT}"
        )
        etcd_process = run_command(etcd_cmd, background=True)
        if not etcd_process:
            raise RuntimeError("Failed to start etcd")
        ingress_process = run_command("dynamo run in=http out=dyn", background=True)
        if not ingress_process:
            raise RuntimeError("Failed to start ingress")
    else:
        logging.info(f"Setting up child prefill node: {rank}")
        if not wait_for_etcd(f"http://{prefill_host_ip}:{ETCD_CLIENT_PORT}"):
            raise RuntimeError("Failed to connect to etcd")
    # NOTE: This implements the example in examples/sglang/dsr1-wideep.md
    # For other examples, the command might have to be modified.
    dynamo_cmd = (
        "python3 components/worker.py "
        "--model-path /model/ "
        "--served-model-name deepseek-ai/DeepSeek-R1 "
        "--skip-tokenizer-init "
        "--disaggregation-mode prefill "
        "--disaggregation-transfer-backend nixl "
        "--disaggregation-bootstrap-port 30001 "
        f"--dist-init-addr {prefill_host_ip}:{DIST_INIT_PORT} "
        f"--nnodes {total_nodes} "
        f"--node-rank {rank} "
        f"--tp-size {total_gpus} "
        f"--dp-size {total_gpus} "
        "--enable-dp-attention "
        "--decode-log-interval 1 "
        "--enable-deepep-moe "
        "--page-size 1 "
        "--trust-remote-code "
        "--moe-dense-tp-size 1 "
        "--enable-dp-lm-head "
        "--disable-radix-cache "
        "--watchdog-timeout 1000000 "
        "--enable-two-batch-overlap "
        "--deepep-mode normal "
        "--mem-fraction-static 0.85 "
        "--deepep-config /configs/deepep.json "
        "--ep-num-redundant-experts 32 "
        "--ep-dispatch-algorithm dynamic "
        "--eplb-algorithm deepseek "
    )
    return run_command(dynamo_cmd)
def setup_decode_node(
    rank: int,
    decode_host_ip: str,
    prefill_host_ip: str,
    total_nodes: int,
    total_gpus: int,
) -> int:
    """
    Setup the decode node.
    """
    logging.info(f"Setting up child decode node: {rank}")
    if not wait_for_etcd(f"http://{prefill_host_ip}:{ETCD_CLIENT_PORT}"):
        raise RuntimeError("Failed to connect to etcd")
    dynamo_cmd = (
        "python3 components/decode_worker.py "
        "--model-path /model/ "
        "--served-model-name deepseek-ai/DeepSeek-R1 "
        "--skip-tokenizer-init "
        "--disaggregation-mode decode "
        "--disaggregation-transfer-backend nixl "
        "--disaggregation-bootstrap-port 30001 "
        f"--dist-init-addr {decode_host_ip}:{DIST_INIT_PORT} "
        f"--nnodes {total_nodes} "
        f"--node-rank {rank} "
        f"--tp-size {total_gpus} "
        f"--dp-size {total_gpus} "
        "--enable-dp-attention "
        "--decode-log-interval 1 "
        "--enable-deepep-moe "
        "--page-size 1 "
        "--trust-remote-code "
        "--moe-dense-tp-size 1 "
        "--enable-dp-lm-head "
        "--disable-radix-cache "
        "--watchdog-timeout 1000000 "
        "--enable-two-batch-overlap "
        "--deepep-mode low_latency "
        "--mem-fraction-static 0.835 "
        "--ep-num-redundant-experts 32 "
        "--cuda-graph-bs 256 "
    )
    return run_command(dynamo_cmd)
def setup_env(prefill_host_ip: str):
    nats_server = f"nats://{prefill_host_ip}:{NATS_PORT}"
    etcd_endpoints = f"http://{prefill_host_ip}:{ETCD_CLIENT_PORT}"
    os.environ["NATS_SERVER"] = nats_server
    os.environ["ETCD_ENDPOINTS"] = etcd_endpoints
    logging.info(f"set NATS_SERVER: {nats_server}")
    logging.info(f"set ETCD_ENDPOINTS: {etcd_endpoints}")
def main(input_args: list[str] | None = None):
    setup_logging()
    args = _parse_command_line_args(input_args)
    _validate_args(args)
    if args.gpu_utilization_log:
        log_gpu_utilization(args.gpu_utilization_log)
    logging.info(f"{args.worker_type.capitalize()} node setup started")
    logging.info(f"Hostname: {socket.gethostname()}")
    logging.info(f"Prefill host IP: {args.prefill_host_ip}")
    logging.info(f"Decode host IP: {args.decode_host_ip}")
    logging.info(f"Rank: {args.rank}")
    setup_env(args.prefill_host_ip)
    if args.worker_type == "prefill":
        setup_prefill_node(
            args.rank,
            args.prefill_host_ip,
            args.total_nodes,
            args.total_nodes * args.gpus_per_node,
        )
    else:
        setup_decode_node(
            args.rank,
            args.decode_host_ip,
            args.prefill_host_ip,
            args.total_nodes,
            args.total_nodes * args.gpus_per_node,
        )
    logging.info(f"{args.worker_type.capitalize()} node setup complete")


if __name__ == "__main__":
    main()
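A note on the etcd wait in `worker_setup.py`: `wait_for_etcd` is bounded by an attempt count (1000 attempts with a 2-second sleep), and each attempt can additionally block for up to the 5-second request timeout, so the wall-clock bound is not fixed. The sleeps alone add up to 2000 seconds, which is a large share of the default `01:00:00` job time limit. If a fixed bound is preferred, one alternative is to poll against a deadline. The following is a minimal sketch of that idea under stated assumptions; `wait_until` and its parameters are hypothetical and are not used by the scripts in this commit.

```python
import time
from typing import Callable


def wait_until(
    check: Callable[[], bool],
    timeout_s: float,
    interval_s: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll `check` until it returns True or `timeout_s` seconds have passed.

    `sleep` and `clock` are injectable so the loop can be tested without
    real waiting.
    """
    deadline = clock() + timeout_s
    while True:
        if check():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)
```

In the worker script, such a helper could be called as `wait_until(lambda: check_etcd_health(url), timeout_s=600)`, where the 600-second budget is an arbitrary illustrative value.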
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to generate SLURM job scripts from Jinja2 templates.
"""
import argparse
import logging
import subprocess
import tempfile
from jinja2 import Template
def setup_logging(level: int = logging.INFO) -> None:
    logging.basicConfig(
        level=level,
        format="%(asctime)s| %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
def generate_job_script(template_path, output_path, **kwargs):
    """Generate a job script from template with given parameters."""
    with open(template_path, "r") as f:
        template = Template(f.read())
    rendered_script = template.render(**kwargs)
    with open(output_path, "w") as f:
        f.write(rendered_script)
    return output_path
def submit_job(job_script_path):
    """
    Submit the job script to SLURM and extract the job ID from the output.
    Returns:
        The job ID of the submitted job.
    """
    try:
        result = subprocess.run(
            ["sbatch", job_script_path], capture_output=True, text=True, check=True
        )
        output_lines = result.stdout.strip().split("\n")
        # sbatch typically outputs: "Submitted batch job JOBID"
        job_id = output_lines[-1].split()[-1]
        logging.info(f"Job submitted successfully with ID: {job_id}")
        return job_id
    except subprocess.CalledProcessError as e:
        logging.error(f"Error submitting job: {e}")
        logging.error(f"stderr: {e.stderr}")
        raise
    except (IndexError, ValueError):
        logging.error(f"Error parsing job ID from sbatch output: {result.stdout}")
        raise
def _parse_command_line_args(args: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate and submit SLURM job scripts"
    )
    parser.add_argument(
        "--template", required=True, help="Path to Jinja2 template file"
    )
    # Template parameters
    parser.add_argument("--job-name", default="dynamo_setup", help="SLURM job name")
    parser.add_argument("--account", required=True, help="SLURM account")
    parser.add_argument("--model-dir", required=True, help="Model directory path")
    parser.add_argument("--config-dir", required=True, help="Config directory path")
    parser.add_argument("--container-image", required=True, help="Container image")
    parser.add_argument(
        "--time-limit", default="01:00:00", help="Time limit (HH:MM:SS)"
    )
    parser.add_argument(
        "--prefill-nodes", type=int, default=2, help="Number of prefill nodes"
    )
    parser.add_argument(
        "--decode-nodes", type=int, default=2, help="Number of decode nodes"
    )
    parser.add_argument(
        "--gpus-per-node", type=int, default=8, help="Number of GPUs per node"
    )
    parser.add_argument(
        "--network-interface", default="eth3", help="Network interface to use"
    )
    return parser.parse_args(args)
def main(input_args: list[str] | None = None):
    setup_logging()
    args = _parse_command_line_args(input_args)
    total_nodes = args.prefill_nodes + args.decode_nodes
    template_vars = {
        "job_name": args.job_name,
        "total_nodes": total_nodes,
        "account": args.account,
        "time_limit": args.time_limit,
        "prefill_nodes": args.prefill_nodes,
        "decode_nodes": args.decode_nodes,
        "model_dir": args.model_dir,
        "config_dir": args.config_dir,
        "container_image": args.container_image,
        "gpus_per_node": args.gpus_per_node,
        "network_interface": args.network_interface,
    }
    with tempfile.NamedTemporaryFile(mode="w", suffix=".sh") as temp_file:
        generate_job_script(args.template, temp_file.name, **template_vars)
        job_id = submit_job(temp_file.name)
        logging.info(f"Job logs will be available in: logs/{job_id}/")


if __name__ == "__main__":
    main()
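A note on job ID extraction: `submit_job` takes the last whitespace-separated token of the last output line, which works for the usual `Submitted batch job JOBID` message. If the output ever carries text after the ID, that approach would return the wrong token. A stricter variant, sketched here as an assumption rather than as the script's behavior, matches the numeric ID explicitly; `parse_job_id` is a hypothetical helper.

```python
import re


def parse_job_id(sbatch_stdout: str) -> str:
    """Extract the numeric job ID from sbatch output.

    Raises ValueError when no "Submitted batch job <digits>" message is found.
    """
    match = re.search(r"Submitted batch job (\d+)", sbatch_stdout)
    if match is None:
        raise ValueError(f"No job ID found in sbatch output: {sbatch_stdout!r}")
    return match.group(1)
```

The returned string can be used directly to build the log path, for example `f"logs/{parse_job_id(result.stdout)}/"`.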