Unverified Commit e1c7a4a4 authored by Yunzhou Liu's avatar Yunzhou Liu Committed by GitHub
Browse files

feat: update how we wait for model before model is ready (#3050)


Signed-off-by: default avatarElnifio <elnifio0519@gmail.com>
parent 65a1e1b4
...@@ -13,6 +13,7 @@ The node setup is done using Python job submission scripts with Jinja2 templates ...@@ -13,6 +13,7 @@ The node setup is done using Python job submission scripts with Jinja2 templates
- **`job_script_template.j2`**: Jinja2 template for generating SLURM job scripts - **`job_script_template.j2`**: Jinja2 template for generating SLURM job scripts
- **`scripts/worker_setup.py`**: Worker script that handles the setup on each node - **`scripts/worker_setup.py`**: Worker script that handles the setup on each node
- **`scripts/monitor_gpu_utilization.sh`**: Script for monitoring GPU utilization during benchmarks - **`scripts/monitor_gpu_utilization.sh`**: Script for monitoring GPU utilization during benchmarks
- **`submit_disagg.sh`**: A simple one-liner script that invokes the `submit_job_script.py`
## Logs Folder Structure ## Logs Folder Structure
......
...@@ -17,6 +17,7 @@ PREFILL_WORKERS={{ prefill_workers }} ...@@ -17,6 +17,7 @@ PREFILL_WORKERS={{ prefill_workers }}
DECODE_WORKERS={{ decode_workers }} DECODE_WORKERS={{ decode_workers }}
TOTAL_NODES=$((PREFILL_NODES + DECODE_NODES)) TOTAL_NODES=$((PREFILL_NODES + DECODE_NODES))
GPUS_PER_NODE={{ gpus_per_node }} GPUS_PER_NODE={{ gpus_per_node }}
TOTAL_GPUS=$((TOTAL_NODES * GPUS_PER_NODE))
PREFILL_NODES_PER_WORKER=$((PREFILL_NODES / PREFILL_WORKERS)) PREFILL_NODES_PER_WORKER=$((PREFILL_NODES / PREFILL_WORKERS))
DECODE_NODES_PER_WORKER=$((DECODE_NODES / DECODE_WORKERS)) DECODE_NODES_PER_WORKER=$((DECODE_NODES / DECODE_WORKERS))
LOG_DIR="${SLURM_SUBMIT_DIR}/logs/${SLURM_JOB_ID}" LOG_DIR="${SLURM_SUBMIT_DIR}/logs/${SLURM_JOB_ID}"
...@@ -297,7 +298,7 @@ PROFILER_ARGS="{{ profiler_arg }}" ...@@ -297,7 +298,7 @@ PROFILER_ARGS="{{ profiler_arg }}"
{% if do_profile %} {% if do_profile %}
{% raw %} {% raw %}
srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/profile.out --error=${LOG_DIR}/profile.err --overlap bash /scripts/${PROFILER_TYPE}/bench.sh $PREFILL_WORKERS $DECODE_WORKERS ${PROFILER_ARGS} & srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/profile.out --error=${LOG_DIR}/profile.err --overlap bash /scripts/${PROFILER_TYPE}/bench.sh $PREFILL_WORKERS $DECODE_WORKERS $TOTAL_GPUS ${PROFILER_ARGS} &
{% endraw %} {% endraw %}
{% endif %} {% endif %}
......
...@@ -6,22 +6,26 @@ wait_for_model() { ...@@ -6,22 +6,26 @@ wait_for_model() {
local model_host=$1 local model_host=$1
local model_port=$2 local model_port=$2
local poll=${3:-1} local n_prefill=${3:-1}
local timeout=${4:-600} local n_decode=${4:-1}
local report_every=${5:-60} local poll=${5:-1}
local timeout=${6:-600}
local report_every=${7:-60}
local health_addr="http://${model_host}:${model_port}/health" local health_addr="http://${model_host}:${model_port}/health"
echo "Polling ${health_addr} every ${poll} seconds" echo "Polling ${health_addr} every ${poll} seconds to check whether ${n_prefill} prefills and ${n_decode} decodes are alive"
local start_ts=$(date +%s) local start_ts=$(date +%s)
local report_ts=$(date +%s) local report_ts=$(date +%s)
while :; do while :; do
# Curl timeout - our primary use case here is to launch it at the first node (localhost), so no timeout is needed.
curl_result=$(curl ${health_addr} 2>/dev/null) curl_result=$(curl ${health_addr} 2>/dev/null)
health=$(grep '"status":"healthy"' <<< $curl_result) # Python path - Use of `check_server_health.py` is self-constrained outside of any packaging.
if [[ -n $health ]]; then check_result=$(python3 /scripts/check_server_health.py $n_prefill $n_decode <<< $curl_result)
echo "Model is alive. Health response: ${curl_result}; " if [[ $check_result == *"Model is ready."* ]]; then
return 0; echo $check_result
return 0
fi fi
time_now=$(date +%s) time_now=$(date +%s)
...@@ -31,7 +35,7 @@ wait_for_model() { ...@@ -31,7 +35,7 @@ wait_for_model() {
fi fi
if [[ $((time_now - report_ts)) -ge $report_every ]]; then if [[ $((time_now - report_ts)) -ge $report_every ]]; then
echo "Waiting for model to come alive. Current result: ${curl_result}" echo $check_result
report_ts=$time_now report_ts=$time_now
fi fi
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# pytest: skip-file
import json
import sys
"""
A file that parses the response of `curl <host_ip>:<host_port>/health` endpoint
to check whether the server is ready to be benchmarked.
Usage:
```bash
curl_result=$(curl "${host_ip}:${host_port}/health" 2> /dev/null)
check_result=$(python3 check_server_health.py $N_PREFILL $N_DECODE <<< $curl_result)
# ... then do subsequent processing for check_result ...
```
"""
def check_server_health(expected_n_prefill, expected_n_decode, response):
"""
Checks the health of the server's response
and ensures that the number of spinned up prefill & decode
matches our expectation.
---
Parameter:
- expected_n_prefill: string (expect integer), number of expected prefill workers.
- expected_n_decode: string (expect integer), number of expected decode workers.
- response: string, formatted `curl <url>/health` curl results,
should be JSON-parsable
Returns:
string, a pretty-printable string that tell the current status.
"""
if not (expected_n_prefill.isnumeric() and expected_n_decode.isnumeric()):
return f"Got unparsable expected prefill / decode value: {expected_n_prefill} & {expected_n_decode} should be string"
expected_n_prefill = int(expected_n_prefill)
expected_n_decode = int(expected_n_decode)
try:
decoded_response = json.loads(response)
except json.JSONDecodeError:
return f"Got invalid response from server that leads to JSON Decode error: {response}"
if "instances" not in decoded_response:
return f"Key 'instances' not found in response: {response}"
for instance in decoded_response["instances"]:
if instance.get("endpoint") == "generate":
if instance.get("component") == "prefill":
expected_n_prefill -= 1
if instance.get("component") == "backend":
expected_n_decode -= 1
if expected_n_prefill <= 0 and expected_n_decode <= 0:
return f"Model is ready. Response: {response}"
else:
return f"Model is not ready, waiting for {expected_n_prefill} prefills and {expected_n_decode} decodes to spin up. Response: {response}"
if __name__ == "__main__":
"""
Usage -
provide the expected number of prefill / decode as sys args
and then provide the `curl` response as an input.
E.g.:
```bash
curl_result=$(curl "${host_ip}:${host_port}/health" 2> /dev/null)
check_result=$(python3 check_server_health.py $N_PREFILL $N_DECODE <<< $curl_result)
# ... then do subsequent processing for check_result ...
```
"""
expected_n_prefill = sys.argv[1]
expected_n_decode = sys.argv[2]
response = sys.stdin.read()
print(
check_server_health(
expected_n_prefill=expected_n_prefill,
expected_n_decode=expected_n_decode,
response=response,
)
)
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
prefill_workers=$1 prefill_workers=$1
decode_workers=$2 decode_workers=$2
total_gpus=$3
chosen_isl=$3 chosen_isl=$4
chosen_osl=$4 chosen_osl=$5
chosen_concurrencies=$5 chosen_concurrencies=$6
echo "Profiling for model with PrefillDP=${prefill_workers}, DecodeDP=${decode_workers}" echo "Profiling for model with PrefillDP=${prefill_workers}, DecodeDP=${decode_workers}"
...@@ -23,7 +24,7 @@ echo "Chosen random seed ${random_seed}" ...@@ -23,7 +24,7 @@ echo "Chosen random seed ${random_seed}"
source /scripts/benchmark_utils.sh source /scripts/benchmark_utils.sh
wait_for_model $head_node $head_port 5 2400 60 wait_for_model $head_node $head_port $prefill_workers $decode_workers 5 900 60
set -e set -e
warmup_model $head_node $head_port $SERVED_MODEL_NAME $MODEL_PATH "${chosen_isl}x${chosen_osl}x10000x10000x250" warmup_model $head_node $head_port $SERVED_MODEL_NAME $MODEL_PATH "${chosen_isl}x${chosen_osl}x10000x10000x250"
......
...@@ -2,12 +2,16 @@ ...@@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
chosen_isl=$3 n_prefill=$1
chosen_osl=$4 n_decode=$2
total_gpus=$3
chosen_isl=$4
chosen_osl=$5
concurrency_list=$6
concurrency_list=$5
IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list" IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list"
chosen_req_rate=$6 chosen_req_rate=$7
echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}" echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}"
...@@ -19,12 +23,12 @@ MODEL_PATH=/model/ ...@@ -19,12 +23,12 @@ MODEL_PATH=/model/
source /scripts/benchmark_utils.sh source /scripts/benchmark_utils.sh
wait_for_model $head_node $head_port 5 2400 60 wait_for_model $head_node $head_port $n_prefill $n_decode 5 900 60
sleep 300 sleep 300
set -e set -e
warmup_model $head_node $head_port $SERVED_MODEL_NAME $MODEL_PATH "${chosen_isl}x${chosen_osl}x10000x10000x${chosen_req_rate}" warmup_model $head_node $head_port $SERVED_MODEL_NAME $MODEL_PATH "${chosen_isl}x${chosen_osl}x10000x10000x250"
set +e set +e
profile_folder="/logs/sglang_isl_${chosen_isl}_osl_${chosen_osl}" profile_folder="/logs/sglang_isl_${chosen_isl}_osl_${chosen_osl}"
......
...@@ -331,6 +331,104 @@ async def async_request_openai_completions( ...@@ -331,6 +331,104 @@ async def async_request_openai_completions(
return output return output
async def async_request_dynamo_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
data = json.loads(chunk)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
if usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_openai_chat_completions( async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
...@@ -470,160 +568,10 @@ def get_tokenizer( ...@@ -470,160 +568,10 @@ def get_tokenizer(
) )
async def async_request_dynamo_chat_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
"chat/completions"
), "Dynamo Chat API URL must end with 'chat/completions'."
headers = {"Content-Type": "application/json"}
user_content = request_func_input.prompt
if request_func_input.multi_modal_content:
pass
payload = {
"model": (
request_func_input.model_name
if request_func_input.model_name
else request_func_input.model
),
"messages": [{"role": "user", "content": user_content}],
"temperature": 0.0, # keep deterministic for benchmarks unless you want entropy
"max_tokens": request_func_input.output_len,
"stream": True,
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = True # if unsupported, we’ll retry w/o it below
output = RequestFuncOutput(prompt_len=request_func_input.prompt_len)
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
st = time.perf_counter()
most_recent_timestamp = st
ttft = 0.0
got_stream_chunk = False
parts: list[str] = []
async def _do_request(json_payload):
return await session.post(url=api_url, json=json_payload, headers=headers)
try:
# First try with the requested flags
response = await _do_request(payload)
# If server errors, fall back progressively by removing the usual culprits.
if response.status >= 500:
# remove ignore_eos
if "ignore_eos" in payload:
payload.pop("ignore_eos", None)
response = await _do_request(payload)
if response.status >= 500:
# disable streaming
payload["stream"] = False
response = await _do_request(payload)
# --- STREAMING (SSE) ---
if response.status == 200 and response.headers.get(
"content-type", ""
).startswith("text/event-stream"):
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
s = chunk_bytes.decode("utf-8").lstrip()
if s.startswith(":"): # ping/keepalive
continue
if s.startswith("data:"):
s = s.partition("data:")[2].strip()
if s == "[DONE]":
continue
data = json.loads(s)
timestamp = time.perf_counter()
if choices := data.get("choices"):
got_stream_chunk = True
delta = choices[0].get("delta", {})
# preserve order; keep <think> intact
rc = delta.get("reasoning_content")
if rc is not None:
parts.append(rc)
c = delta.get("content")
if c is not None:
parts.append(c)
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens", output.output_tokens
)
output.generated_text = "".join(parts)
output.latency = (
most_recent_timestamp if got_stream_chunk else time.perf_counter()
) - st
output.success = got_stream_chunk
if not got_stream_chunk:
# capture error body if available
try:
output.error = await response.text()
except Exception:
output.error = "No stream token chunks received."
# --- NON-STREAMING JSON ---
elif response.status == 200:
body = await response.json()
ch0 = (body.get("choices") or [{}])[0]
msg = ch0.get("message") or {}
# Prefer reasoning + content from message
if msg.get("reasoning_content") is not None:
parts.append(msg["reasoning_content"])
if msg.get("content") is not None:
parts.append(msg["content"])
# Fallbacks (some implementations put text/rc at choice-level)
if not parts:
if ch0.get("reasoning_content") is not None:
parts.append(ch0["reasoning_content"])
if ch0.get("text") is not None:
parts.append(ch0["text"])
output.generated_text = "".join(parts)
output.latency = time.perf_counter() - st
output.ttft = 0.0
if usage := body.get("usage"):
output.output_tokens = usage.get("completion_tokens", 0)
output.success = True
else:
# Better error visibility for your “Initial test run failed”
output.success = False
output.error = f"HTTP {response.status}: " + (await response.text())
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
ASYNC_REQUEST_FUNCS = { ASYNC_REQUEST_FUNCS = {
"tgi": async_request_tgi, "tgi": async_request_tgi,
"vllm": async_request_openai_completions, "vllm": async_request_openai_completions,
"dynamo": async_request_dynamo_completions,
"lmdeploy": async_request_openai_completions, "lmdeploy": async_request_openai_completions,
"deepspeed-mii": async_request_deepspeed_mii, "deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_openai_completions, "openai": async_request_openai_completions,
...@@ -631,5 +579,4 @@ ASYNC_REQUEST_FUNCS = { ...@@ -631,5 +579,4 @@ ASYNC_REQUEST_FUNCS = {
"tensorrt-llm": async_request_trt_llm, "tensorrt-llm": async_request_trt_llm,
"scalellm": async_request_openai_completions, "scalellm": async_request_openai_completions,
"sglang": async_request_openai_completions, "sglang": async_request_openai_completions,
"dynamo": async_request_dynamo_chat_completions,
} }
...@@ -9,22 +9,51 @@ model_path="/model/" ...@@ -9,22 +9,51 @@ model_path="/model/"
head_node="localhost" head_node="localhost"
head_port=8000 head_port=8000
n_prefill=$1
n_decode=$2
total_gpus=$3
source /scripts/benchmark_utils.sh source /scripts/benchmark_utils.sh
work_dir="/scripts/vllm/" work_dir="/scripts/vllm/"
cd $work_dir cd $work_dir
chosen_isl=$3 chosen_isl=$4
chosen_osl=$4 chosen_osl=$5
concurrency_list=$5 concurrency_list=$6
IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list" IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list"
chosen_req_rate=$6 chosen_req_rate=$7
echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}" echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}"
wait_for_model $head_node $head_port 5 2400 60 wait_for_model_timeout=1500 # 25 minutes
wait_for_model_check_interval=5 # check interval -> 5s
wait_for_model_report_interval=60 # wait_for_model report interval -> 60s
wait_for_model $head_node $head_port $n_prefill $n_decode $wait_for_model_check_interval $wait_for_model_timeout $wait_for_model_report_interval
set -e set -e
warmup_model $head_node $head_port $model_name $model_path "${chosen_isl}x${chosen_osl}x10000x10000x${chosen_req_rate}" # Warmup the model
warmup_isl=$chosen_isl
warmup_osl=$chosen_osl
warmup_prompts=10000
warmup_concurrencies=10000
warmup_req_rate=250
set -x
python3 benchmark_serving.py \
--model ${model_name} --tokenizer ${model_path} \
--host $head_node --port $head_port \
--backend "dynamo" --endpoint /v1/completions \
--disable-tqdm \
--dataset-name random \
--num-prompts "$warmup_prompts" \
--random-input-len $warmup_isl \
--random-output-len $warmup_osl \
--random-range-ratio 0.8 \
--ignore-eos \
--request-rate ${warmup_req_rate} \
--percentile-metrics ttft,tpot,itl,e2el \
--max-concurrency "$warmup_concurrencies"
set +x
set +e set +e
result_dir="/logs/vllm_isl_${chosen_isl}_osl_${chosen_osl}" result_dir="/logs/vllm_isl_${chosen_isl}_osl_${chosen_osl}"
...@@ -35,13 +64,13 @@ for concurrency in "${chosen_concurrencies[@]}" ...@@ -35,13 +64,13 @@ for concurrency in "${chosen_concurrencies[@]}"
do do
num_prompts=$((concurrency * 5)) num_prompts=$((concurrency * 5))
echo "Running benchmark with concurrency: $concurrency and num-prompts: $num_prompts, writing to file ${result_dir}" echo "Running benchmark with concurrency: $concurrency and num-prompts: $num_prompts, writing to file ${result_dir}"
result_filename="isl_${chosen_isl}_osl_${chosen_osl}_concurrency_${concurrency}_req_rate_${chosen_req_rate}.json" result_filename="isl_${chosen_isl}_osl_${chosen_osl}_concurrency_${concurrency}_req_rate_${chosen_req_rate}_gpus${total_gpus}.json"
set -x set -x
python3 benchmark_serving.py \ python3 benchmark_serving.py \
--model ${model_name} --tokenizer ${model_path} \ --model ${model_name} --tokenizer ${model_path} \
--host $head_node --port $head_port \ --host $head_node --port $head_port \
--backend "dynamo" --endpoint /v1/chat/completions \ --backend "dynamo" --endpoint /v1/completions \
--disable-tqdm \ --disable-tqdm \
--dataset-name random \ --dataset-name random \
--num-prompts "$num_prompts" \ --num-prompts "$num_prompts" \
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
usage() {
cat << 'USAGE'
This script aims to provide a one-liner call to the submit_job_script.py,
so that the deployment process can be further simplified.
To use this script, fill in the following script and run it under your `slurm_jobs` directory:
======== begin script area ========
export SLURM_ACCOUNT=
export SLURM_PARTITION=
export TIME_LIMIT=
# Add path to your DSR1-FP8 model directory here
export MODEL_PATH=
# This path should contain the deepep.json and optionally init expert locations.
# Please refer to the README for more detail.
export CONFIG_DIR=
# Add path to your container image here, either as a link or as a cached file
export CONTAINER_IMAGE=
bash submit_disagg.sh \
$PREFILL_NODES $PREFILL_WORKERS $DECODE_NODES $DECODE_WORKERS \
$ADDITIONAL_FRONTENDS \
$ISL $OSL $CONCURRENCIES $REQUEST_RATE
======== end script area ========
USAGE
}
check_env() {
local name="$1"
if [[ -z "${!name:-}" ]]; then
echo "Error: ${name} not specified" >&2
usage >&2
exit 1
fi
}
check_env SLURM_ACCOUNT
check_env SLURM_PARTITION
check_env TIME_LIMIT
check_env MODEL_PATH
check_env CONFIG_DIR
check_env CONTAINER_IMAGE
GPU_TYPE="gb200-fp8"
GPUS_PER_NODE=4
: "${NETWORK_INTERFACE:=enP6p9s0np0}"
# COMMAND_LINE ARGS
PREFILL_NODES=$1
PREFILL_WORKERS=$2
DECODE_NODES=$3
DECODE_WORKERS=$4
N_ADDITIONAL_FRONTENDS=$5
ISL=$6
OSL=$7
CONCURRENCIES=$8
REQUEST_RATE=$9
RETRIES=1 # defaults to retry the job 1 time to avoid transient errors
# Should not need retries
profiler_args="type=vllm; isl=${ISL}; osl=${OSL}; concurrencies=${CONCURRENCIES}; req-rate=${REQUEST_RATE}"
USE_INIT_LOCATIONS=()
if [[ $PREFILL_NODES -eq 6 ]] && [[ $PREFILL_WORKERS -eq 3 ]] && [[ $DECODE_NODES -eq 12 ]] && [[ $DECODE_WORKERS -eq 1 ]]; then
USE_INIT_LOCATIONS=(--use-init-location)
fi
command=(
python3 submit_job_script.py
--account $SLURM_ACCOUNT --partition $SLURM_PARTITION --time-limit $TIME_LIMIT
--template job_script_template.j2
--model-dir $MODEL_PATH --config-dir $CONFIG_DIR
--container-image $CONTAINER_IMAGE
--gpu-type $GPU_TYPE --gpus-per-node $GPUS_PER_NODE --network-interface $NETWORK_INTERFACE
--prefill-nodes $PREFILL_NODES --prefill-workers $PREFILL_WORKERS
--decode-nodes $DECODE_NODES --decode-workers $DECODE_WORKERS
--enable-multiple-frontends --num-additional-frontends $N_ADDITIONAL_FRONTENDS ${USE_INIT_LOCATIONS[@]}
--profiler "${profiler_args}"
--retries $RETRIES
)
"${command[@]}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment