"docs/integrations/lmcache-integration.md" did not exist on "784da90ed06c31a460454023f8b59aa42a5ea3e6"
Unverified Commit 7860861f authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: Add TTFT and ITL Interpolation to Profiling Script (#1159)


Co-authored-by: default avatarroot <root@kkranen-dt.nvidia.com>
parent 3bde1e45
...@@ -31,6 +31,7 @@ protobuf==5.27.3 ...@@ -31,6 +31,7 @@ protobuf==5.27.3
pydantic==2.7.1 pydantic==2.7.1
pyright pyright
PyYAML PyYAML
scikit-learn
sentencepiece sentencepiece
tensorboard==2.19.0 tensorboard==2.19.0
tensorboardX==2.6.2.2 tensorboardX==2.6.2.2
......
...@@ -86,6 +86,8 @@ The following information will be printed out in the terminal: ...@@ -86,6 +86,8 @@ The following information will be printed out in the terminal:
2025-05-16 15:20:24 - __main__ - INFO - Suggested planner upper/lower bound for decode kv cache utilization: 0.20/0.10 2025-05-16 15:20:24 - __main__ - INFO - Suggested planner upper/lower bound for decode kv cache utilization: 0.20/0.10
``` ```
After finding the best TP size for prefill and decode, the script then interpolates TTFT as a function of ISL, and ITL as a function of active KV cache usage and decode context length. This provides a more accurate performance estimate when ISL and OSL change. The results are saved to `<output_dir>/<decode/prefill>_tp<best_tp>_interpolation`.
## Usage ## Usage
The planner is started automatically as part of Dynamo pipelines when running `dynamo serve`. You can configure the planner just as you would any other component in your pipeline either via YAML configuration or through CLI arguments. The planner is started automatically as part of Dynamo pipelines when running `dynamo serve`. You can configure the planner just as you would any other component in your pipeline either via YAML configuration or through CLI arguments.
......
...@@ -28,6 +28,8 @@ import matplotlib.pyplot as plt ...@@ -28,6 +28,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import requests import requests
import yaml import yaml
from matplotlib import cm
from scipy.interpolate import griddata
DECODE_NUM_REQUESTS_RANGE = [ DECODE_NUM_REQUESTS_RANGE = [
1, 1,
...@@ -274,6 +276,29 @@ def get_port(config: dict) -> int: ...@@ -274,6 +276,29 @@ def get_port(config: dict) -> int:
return config["Frontend"]["port"] return config["Frontend"]["port"]
def shutdown_deployment(dynamo_process):
    """Stop a dynamo deployment and force-kill leftover Python processes.

    Sends SIGINT to the process group of ``dynamo_process``, waits for it to
    exit, then lists all processes with ``ps -ef`` and sends SIGKILL to every
    process whose ``ps`` line contains "python" (case-insensitive), except the
    current process. Finally sleeps for 5 seconds.

    Args:
        dynamo_process: ``subprocess.Popen`` handle of the ``dynamo serve``
            process to stop.
    """
    try:
        os.killpg(os.getpgid(dynamo_process.pid), signal.SIGINT)
    except ProcessLookupError:
        # The process group is already gone; nothing to interrupt.
        logger.warning("Dynamo process group already exited before SIGINT")
    dynamo_process.communicate()

    # NOTE(review): this sweep matches "python" anywhere in the ps line, so it
    # also kills unrelated Python processes on the machine. Kept as-is because
    # callers rely on it for cleanup -- confirm this is acceptable on shared
    # hosts.
    try:
        current_pid = os.getpid()
        ps_output = subprocess.check_output(["ps", "-ef"], text=True)
        for line in ps_output.splitlines():
            if "python" not in line.lower():
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            try:
                pid = int(parts[1])
            except ValueError:
                continue
            if pid == current_pid:  # Exclude current process
                continue
            try:
                os.kill(pid, signal.SIGKILL)
            except (ProcessLookupError, PermissionError) as e:
                # A process that already exited or belongs to another user
                # must not abort the sweep over the remaining PIDs.
                logger.debug(f"Skipping PID {pid}: {e}")
    except Exception as e:
        logger.error(f"Error killing Python processes: {e}")
    time.sleep(5)
def wait_for_server_ready(model_name: str, port: int, timeout: int = 300): def wait_for_server_ready(model_name: str, port: int, timeout: int = 300):
logger.info("Waiting for the server to be ready...") logger.info("Waiting for the server to be ready...")
endpoint_url = f"http://localhost:{port}/v1/chat/completions" endpoint_url = f"http://localhost:{port}/v1/chat/completions"
...@@ -332,6 +357,79 @@ def get_gap_result(artifact_dir: str) -> dict: ...@@ -332,6 +357,79 @@ def get_gap_result(artifact_dir: str) -> dict:
return json.load(f) return json.load(f)
def benchmark_prefill(isl, genai_perf_artifact_dir, model_name, port):
    """Run one genai-perf prefill benchmark and return its parsed result.

    Args:
        isl: input sequence length passed to the genai-perf command builder.
        genai_perf_artifact_dir: directory genai-perf writes artifacts to and
            from which the result is loaded.
        model_name: model name forwarded to the command builder.
        port: port forwarded to the command builder.

    Returns:
        The value returned by ``get_gap_result`` for the artifact directory
        when genai-perf exits with code 0, otherwise ``None``.
    """
    logger.info(f"Running genai-perf with isl {isl}")
    cmd = get_prefill_genai_perf_cmd(
        isl, genai_perf_artifact_dir, model=model_name, port=port
    )
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    out, err = proc.communicate()

    # Failure path first: report and bail out.
    if proc.returncode != 0:
        logger.error(f"Genai-perf failed with error code: {proc.returncode}")
        logger.error(f"stderr: {err}")
        return None

    logger.info("Genai-perf profiling completed successfully")
    logger.info(out)
    return get_gap_result(genai_perf_artifact_dir)
def benchmark_decode(isl, osl, num_request, genai_perf_artifact_dir, model_name, port):
    """Run one genai-perf decode benchmark (warm-up + measured run).

    The same genai-perf command is issued twice with an identical random
    seed: first into ``<genai_perf_artifact_dir>_warmup`` and then into
    ``genai_perf_artifact_dir``. Only the second run's result is returned.

    Args:
        isl: input sequence length for the benchmark requests.
        osl: output sequence length for the benchmark requests.
        num_request: request count forwarded to the command builder.
        genai_perf_artifact_dir: artifact directory of the measured run.
        model_name: model name forwarded to the command builder.
        port: port forwarded to the command builder.

    Returns:
        The value returned by ``get_gap_result`` for the measured run when
        genai-perf exits with code 0, otherwise ``None``.
    """
    logger.info(f"Profiling decode with num_request {num_request}...")

    # first warm-up the engine by pre-computing all prefill tokens
    # we use the same random seed to make sure the prompt is the same
    seed = random.randint(0, 1000000)

    # Use the isl/osl arguments (not the global CLI args) so callers that
    # sweep over sequence lengths actually benchmark the requested lengths.
    warmup_cmd = get_decode_genai_perf_cmd(
        isl,
        osl,
        f"{genai_perf_artifact_dir}_warmup",
        num_request,
        seed=seed,
        model=model_name,
        port=port,
    )
    warmup_process = subprocess.Popen(
        warmup_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    # The warm-up output and exit status are intentionally not inspected.
    warmup_process.communicate()

    # then send out the real requests, hopefully, this will skip all prefill computation
    genai_perf_cmd = get_decode_genai_perf_cmd(
        isl,
        osl,
        genai_perf_artifact_dir,
        num_request,
        seed=seed,
        model=model_name,
        port=port,
    )
    gap_process = subprocess.Popen(
        genai_perf_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    stdout, stderr = gap_process.communicate()

    if gap_process.returncode != 0:
        logger.error(f"Genai-perf failed with error code: {gap_process.returncode}")
        logger.error(f"stderr: {stderr}")
        return None

    logger.info("Genai-perf profiling completed successfully")
    logger.info(stdout)
    return get_gap_result(genai_perf_artifact_dir)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
...@@ -355,6 +453,25 @@ if __name__ == "__main__": ...@@ -355,6 +453,25 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--itl", type=int, default=5, help="target Inter Token Latency in ms" "--itl", type=int, default=5, help="target Inter Token Latency in ms"
) )
# Arguments below are used only by the interpolation phases, which estimate
# TTFT and ITL under different ISL/OSL after the TP sweep.
# --max-context-length is the exclusive upper bound of the ISL sweeps.
parser.add_argument(
"--max-context-length",
type=int,
default=16384,
help="maximum context length supported by the served model",
)
# The prefill ISL sweep starts at 100 and steps by
# (max_context_length - 100) // prefill_interpolation_granularity.
parser.add_argument(
"--prefill-interpolation-granularity",
type=int,
default=16,
help="how many samples to benchmark to interpolate TTFT under different ISL",
)
# The decode sweep uses this value twice: for the ISL step and for the
# concurrency (num_request) step.
parser.add_argument(
"--decode-interpolation-granularity",
type=int,
default=6,
help="how many samples to benchmark to interpolate ITL under different active kv cache size and decode context length",
)
args = parser.parse_args() args = parser.parse_args()
with open(args.config, "r") as f: with open(args.config, "r") as f:
...@@ -379,12 +496,12 @@ if __name__ == "__main__": ...@@ -379,12 +496,12 @@ if __name__ == "__main__":
prefill_config = convert_config(config, "prefill") prefill_config = convert_config(config, "prefill")
for tp_size in profile_tp_size: for tp_size in profile_tp_size:
logger.info(f"Profiling prefill with TP size {tp_size}...") logger.info(f"Profiling prefill with TP size {tp_size}...")
prefill_config = set_config_tp_size(prefill_config, tp_size)
logger.info(f"Dynamo config: {prefill_config}") logger.info(f"Dynamo config: {prefill_config}")
work_dir = f"{args.output_dir}/prefill_tp{tp_size}" work_dir = f"{args.output_dir}/prefill_tp{tp_size}"
os.makedirs(work_dir, exist_ok=True) os.makedirs(work_dir, exist_ok=True)
prefill_config = set_config_tp_size(prefill_config, tp_size)
prefill_config_fn = f"{work_dir}/config.yaml" prefill_config_fn = f"{work_dir}/config.yaml"
dynamo_log_fn = f"{work_dir}/dynamo.log" dynamo_log_fn = f"{work_dir}/dynamo.log"
with open(prefill_config_fn, "w") as f: with open(prefill_config_fn, "w") as f:
...@@ -407,30 +524,17 @@ if __name__ == "__main__": ...@@ -407,30 +524,17 @@ if __name__ == "__main__":
break break
# run genai-perf # run genai-perf
logger.info(f"Running genai-perf with isl {args.isl}")
genai_perf_artifact_dir = f"{work_dir}/gap_isl{args.isl}" genai_perf_artifact_dir = f"{work_dir}/gap_isl{args.isl}"
genai_perf_cmd = get_prefill_genai_perf_cmd( gap_result = benchmark_prefill(
args.isl, genai_perf_artifact_dir, model=model_name, port=port args.isl, genai_perf_artifact_dir, model_name, port
) )
gap_process = subprocess.Popen( if gap_result is not None:
genai_perf_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
stdout, stderr = gap_process.communicate()
if gap_process.returncode == 0:
logger.info("Genai-perf profiling completed successfully")
logger.info(stdout)
gap_result = get_gap_result(genai_perf_artifact_dir)
ttft = gap_result["time_to_first_token"]["avg"] ttft = gap_result["time_to_first_token"]["avg"]
prefill_tp_size.append(tp_size) prefill_tp_size.append(tp_size)
prefill_ttft.append(ttft) prefill_ttft.append(ttft)
prefill_thpt_per_gpu.append(args.isl / ttft / tp_size * 1000) prefill_thpt_per_gpu.append(args.isl / ttft / tp_size * 1000)
else:
logger.error(f"Genai-perf failed with error code: {gap_process.returncode}")
logger.error(f"stderr: {stderr}")
# Send SIGINT to the dynamo process to terminate it gracefully shutdown_deployment(dynamo_process)
os.killpg(os.getpgid(dynamo_process.pid), signal.SIGINT)
dynamo_process.communicate()
# Plot the results as a 2D scatter plot # Plot the results as a 2D scatter plot
if prefill_tp_size and prefill_ttft and prefill_thpt_per_gpu: if prefill_tp_size and prefill_ttft and prefill_thpt_per_gpu:
...@@ -471,12 +575,12 @@ if __name__ == "__main__": ...@@ -471,12 +575,12 @@ if __name__ == "__main__":
decode_config = convert_config(config, "decode") decode_config = convert_config(config, "decode")
for tp_size in profile_tp_size: for tp_size in profile_tp_size:
logger.info(f"Profiling decode with TP size {tp_size}...") logger.info(f"Profiling decode with TP size {tp_size}...")
decode_config = set_config_tp_size(decode_config, tp_size)
logger.info(f"Dynamo config: {decode_config}") logger.info(f"Dynamo config: {decode_config}")
work_dir = f"{args.output_dir}/decode_tp{tp_size}" work_dir = f"{args.output_dir}/decode_tp{tp_size}"
os.makedirs(work_dir, exist_ok=True) os.makedirs(work_dir, exist_ok=True)
decode_config = set_config_tp_size(decode_config, tp_size)
decode_config_fn = f"{work_dir}/config.yaml" decode_config_fn = f"{work_dir}/config.yaml"
dynamo_log_fn = f"{work_dir}/dynamo.log" dynamo_log_fn = f"{work_dir}/dynamo.log"
with open(decode_config_fn, "w") as f: with open(decode_config_fn, "w") as f:
...@@ -510,50 +614,16 @@ if __name__ == "__main__": ...@@ -510,50 +614,16 @@ if __name__ == "__main__":
engine_decode_itl = [] engine_decode_itl = []
engine_decode_thpt_per_gpu = [] engine_decode_thpt_per_gpu = []
for num_request in sweep_num_request: for num_request in sweep_num_request:
logger.info(f"Profiling decode with num_request {num_request}...")
# first warm-up the engine by pre-computing all prefill tokens
# we use the same random seed to make sure the prompt is the same
seed = random.randint(0, 1000000)
genai_perf_artifact_dir = f"{work_dir}/gap_request{num_request}_isl{args.isl}_osl{args.osl}_n{num_request}_warmup"
genai_perf_cmd = get_decode_genai_perf_cmd(
args.isl,
args.osl,
genai_perf_artifact_dir,
num_request,
seed=seed,
model=model_name,
port=port,
)
gap_process = subprocess.Popen(
genai_perf_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
gap_process.communicate()
# then send out the real requests, hopefully, this will skip all prefill computation
genai_perf_artifact_dir = f"{work_dir}/gap_request{num_request}_isl{args.isl}_osl{args.osl}_n{num_request}" genai_perf_artifact_dir = f"{work_dir}/gap_request{num_request}_isl{args.isl}_osl{args.osl}_n{num_request}"
genai_perf_cmd = get_decode_genai_perf_cmd( gap_result = benchmark_decode(
args.isl, args.isl,
args.osl, args.osl,
genai_perf_artifact_dir,
num_request, num_request,
seed=seed, genai_perf_artifact_dir,
model=model_name, model_name,
port=port, port,
)
gap_process = subprocess.Popen(
genai_perf_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
) )
stdout, stderr = gap_process.communicate() if gap_result is not None:
if gap_process.returncode == 0:
logger.info("Genai-perf profiling completed successfully")
logger.info(stdout)
gap_result = get_gap_result(genai_perf_artifact_dir)
itl = gap_result["inter_token_latency"]["avg"] itl = gap_result["inter_token_latency"]["avg"]
thpt_per_gpu = gap_result["output_token_throughput"]["avg"] / tp_size thpt_per_gpu = gap_result["output_token_throughput"]["avg"] / tp_size
engine_decode_itl.append(itl) engine_decode_itl.append(itl)
...@@ -563,15 +633,8 @@ if __name__ == "__main__": ...@@ -563,15 +633,8 @@ if __name__ == "__main__":
decode_thpt_per_gpu.append(thpt_per_gpu) decode_thpt_per_gpu.append(thpt_per_gpu)
decode_concurrency.append(num_request) decode_concurrency.append(num_request)
decode_kv_cache_size.append(max_kv_tokens) decode_kv_cache_size.append(max_kv_tokens)
else:
logger.error(
f"Genai-perf failed with error code: {gap_process.returncode}"
)
logger.error(f"stderr: {stderr}")
# Send SIGINT to the dynamo process to terminate it gracefully shutdown_deployment(dynamo_process)
os.killpg(os.getpgid(dynamo_process.pid), signal.SIGINT)
dynamo_process.communicate()
# Plot a line in the 2d plot # Plot a line in the 2d plot
plt.plot(engine_decode_itl, engine_decode_thpt_per_gpu, label=f"TP{tp_size}") plt.plot(engine_decode_itl, engine_decode_thpt_per_gpu, label=f"TP{tp_size}")
...@@ -645,3 +708,243 @@ if __name__ == "__main__": ...@@ -645,3 +708,243 @@ if __name__ == "__main__":
logger.info( logger.info(
f"Suggested planner upper/lower bound for decode kv cache utilization: {min(1, selected_decode_kv_cache_utilization + 0.2):.2f}/{max(0.1, selected_decode_kv_cache_utilization - 0.2):.2f}" f"Suggested planner upper/lower bound for decode kv cache utilization: {min(1, selected_decode_kv_cache_utilization + 0.2):.2f}/{max(0.1, selected_decode_kv_cache_utilization - 0.2):.2f}"
) )
# Interpolate ISL -> TTFT with the best prefill TP.
#
# A fresh deployment is started with the selected prefill TP size, TTFT is
# measured over a sweep of input sequence lengths, and quadratic fits of TTFT
# and per-GPU throughput against ISL are plotted into the work directory.
best_prefill_tp = prefill_tp_size[selected_prefill_idx]
prefill_isl = []
prefill_ttft = []
prefill_thpt_per_gpu = []
logger.info(
    f"Profiling prefill under best TP {best_prefill_tp} with different ISL..."
)
prefill_config = convert_config(config, "prefill")
# Use best_prefill_tp here: `tp_size` is only the leftover loop variable of
# the earlier TP sweep and does not identify the selected configuration.
prefill_config = set_config_tp_size(prefill_config, best_prefill_tp)
logger.info(f"Dynamo config: {prefill_config}")

work_dir = f"{args.output_dir}/prefill_tp{best_prefill_tp}_interpolation"
os.makedirs(work_dir, exist_ok=True)

prefill_config_fn = f"{work_dir}/config.yaml"
dynamo_log_fn = f"{work_dir}/dynamo.log"
with open(prefill_config_fn, "w") as f:
    yaml.dump(prefill_config, f)

# Start the dynamo serve process
logger.info(f"Starting dynamo serve with TP size {best_prefill_tp}...")
dynamo_serve_cmd = get_dynamo_serve_cmd(prefill_config_fn)
with open(dynamo_log_fn, "w") as dynamo_log_f:
    dynamo_process = subprocess.Popen(
        dynamo_serve_cmd,
        stdout=dynamo_log_f,
        stderr=subprocess.STDOUT,
        text=True,
        preexec_fn=os.setsid,  # Use process group for clean termination
    )

if not wait_for_server_ready(model_name, port):
    logger.error(
        f"Server did not become ready, skip profiling tp={best_prefill_tp}"
    )
else:
    # Clamp the step to >= 1: range() raises ValueError on a zero step, which
    # happens when the granularity exceeds the sweep width.
    isl_step = max(
        1,
        (args.max_context_length - 100) // args.prefill_interpolation_granularity,
    )
    for isl in range(100, args.max_context_length, isl_step):
        # run genai-perf
        genai_perf_artifact_dir = f"{work_dir}/gap_isl{isl}"
        gap_result = benchmark_prefill(
            isl, genai_perf_artifact_dir, model_name, port
        )
        if gap_result is not None:
            ttft = gap_result["time_to_first_token"]["avg"]
            prefill_isl.append(isl)
            prefill_ttft.append(ttft)
            prefill_thpt_per_gpu.append(isl / ttft / best_prefill_tp * 1000)

# NOTE(review): the deployment is shut down whether or not the server became
# ready, since the process was started either way -- confirm against the
# original indentation.
shutdown_deployment(dynamo_process)

# Interpolate prefill_ttft vs prefill_isl with quadratic function (y=ax^2+bx+c)
if len(prefill_isl) > 2:
    logger.info("Interpolating prefill TTFT and throughput vs ISL...")

    # Convert to numpy arrays for easier manipulation
    prefill_isl_np = np.array(prefill_isl)
    prefill_ttft_np = np.array(prefill_ttft)
    prefill_thpt_per_gpu_np = np.array(prefill_thpt_per_gpu)

    # Fit quadratic functions
    ttft_coeffs = np.polyfit(prefill_isl_np, prefill_ttft_np, 2)
    thpt_coeffs = np.polyfit(prefill_isl_np, prefill_thpt_per_gpu_np, 2)

    # Create interpolation functions
    ttft_poly = np.poly1d(ttft_coeffs)
    thpt_poly = np.poly1d(thpt_coeffs)

    # Generate points for smooth curves
    x_interp = np.linspace(min(prefill_isl_np), max(prefill_isl_np), 100)
    ttft_interp = ttft_poly(x_interp)
    thpt_interp = thpt_poly(x_interp)

    # Plot TTFT vs ISL
    plt.figure(figsize=(10, 6))
    plt.scatter(prefill_isl_np, prefill_ttft_np, s=100, label="Measured data")
    plt.plot(
        x_interp,
        ttft_interp,
        "r-",
        label=f"Quadratic fit: {ttft_coeffs[0]:.2e}x² + {ttft_coeffs[1]:.2e}x + {ttft_coeffs[2]:.2e}",
    )
    plt.title("Prefill TTFT vs Input Sequence Length")
    plt.xlabel("Input Sequence Length (tokens)")
    plt.ylabel("Time to First Token (ms)")
    plt.grid(True)
    plt.legend()

    ttft_plot_path = f"{work_dir}/prefill_ttft_interpolation.png"
    plt.savefig(ttft_plot_path, dpi=300)
    logger.info(f"TTFT interpolation plot saved to {ttft_plot_path}")
    plt.close()

    # Plot Throughput vs ISL
    plt.figure(figsize=(10, 6))
    plt.scatter(
        prefill_isl_np, prefill_thpt_per_gpu_np, s=100, label="Measured data"
    )
    plt.plot(
        x_interp,
        thpt_interp,
        "g-",
        label=f"Quadratic fit: {thpt_coeffs[0]:.2e}x² + {thpt_coeffs[1]:.2e}x + {thpt_coeffs[2]:.2e}",
    )
    plt.title("Prefill Throughput vs Input Sequence Length")
    plt.xlabel("Input Sequence Length (tokens)")
    plt.ylabel("Prefill throughput per GPU (tokens/s/GPU)")
    plt.grid(True)
    plt.legend()

    thpt_plot_path = f"{work_dir}/prefill_throughput_interpolation.png"
    plt.savefig(thpt_plot_path, dpi=300)
    logger.info(
        f"Prefill throughput per GPU interpolation plot saved to {thpt_plot_path}"
    )
    plt.close()
else:
    logger.warning(
        "Not enough data points to perform interpolation (need at least 3 points)"
    )
# Interpolate ITL over (active KV cache usage, decode context length) with the
# best decode TP.
#
# A deployment is started with the selected decode TP size, ITL and per-GPU
# throughput are sampled over a grid of ISL and concurrency values, the raw
# samples are saved to an .npz file, and an interpolated ITL surface is plotted.
x_kv_usage = []
y_context_length = []
z_itl = []
z_thpt_per_gpu = []
best_decode_tp = decode_tp_size[selected_decode_idx]
logger.info(f"Profiling decode with TP size {best_decode_tp}...")
decode_config = set_config_tp_size(decode_config, best_decode_tp)
logger.info(f"Dynamo config: {decode_config}")

work_dir = f"{args.output_dir}/decode_tp{best_decode_tp}_interpolation"
os.makedirs(work_dir, exist_ok=True)

decode_config_fn = f"{work_dir}/config.yaml"
dynamo_log_fn = f"{work_dir}/dynamo.log"
with open(decode_config_fn, "w") as f:
    yaml.dump(decode_config, f)

# Start the dynamo serve process
# All messages, file names and the per-GPU divisor below use best_decode_tp:
# `tp_size` is only the leftover loop variable of the earlier TP sweep.
logger.info(f"Starting dynamo serve with TP size {best_decode_tp}...")
dynamo_serve_cmd = get_dynamo_serve_cmd(decode_config_fn)
with open(dynamo_log_fn, "w") as dynamo_log_f:
    dynamo_process = subprocess.Popen(
        dynamo_serve_cmd,
        stdout=dynamo_log_f,
        stderr=subprocess.STDOUT,
        text=True,
        preexec_fn=os.setsid,  # Use process group for clean termination
    )

if not wait_for_server_ready(model_name, port):
    logger.error(
        f"Server did not become ready, skip profiling tp={best_decode_tp}"
    )
else:
    max_kv_tokens = get_kv_cache_size_from_dynamo_log(dynamo_log_fn)

    osl = 500  # not too large to reduce ITL variance, not too small to have stable measurement
    granularity = args.decode_interpolation_granularity
    # Clamp both steps to >= 1: range() raises ValueError on a zero step.
    isl_step = max(1, (args.max_context_length - osl) // granularity)
    for isl in range(100, args.max_context_length - osl, isl_step):
        max_concurrency = max_kv_tokens // (isl + osl)
        # For long contexts max_concurrency can drop below the granularity,
        # which would otherwise make the step 0.
        concurrency_step = max(1, max_concurrency // granularity)
        for num_request in range(1, max_concurrency, concurrency_step):
            genai_perf_artifact_dir = (
                f"{work_dir}/gap_isl{isl}_osl{osl}_n{num_request}"
            )
            gap_result = benchmark_decode(
                isl, osl, num_request, genai_perf_artifact_dir, model_name, port
            )
            if gap_result is not None:
                itl = gap_result["inter_token_latency"]["avg"]
                # isl + osl / 2 is used as the context length of a request
                # over the measured run.
                x_kv_usage.append((isl + osl / 2) * num_request / max_kv_tokens)
                y_context_length.append(isl + osl / 2)
                z_itl.append(itl)
                z_thpt_per_gpu.append(
                    gap_result["output_token_throughput"]["avg"] / best_decode_tp
                )

# NOTE(review): the deployment is shut down whether or not the server became
# ready, since the process was started either way -- confirm against the
# original indentation.
shutdown_deployment(dynamo_process)

# Save the data points to a .npz file
save_path = f"{work_dir}/decode_tp{best_decode_tp}_data.npz"
np.savez(
    save_path,
    x_kv_usage=np.array(x_kv_usage),
    y_context_length=np.array(y_context_length),
    z_itl=np.array(z_itl),
    z_thpt_per_gpu=np.array(z_thpt_per_gpu),
)
logger.info(f"Saved data points to {save_path}")

if len(x_kv_usage) < 3:
    # min()/max() fail on empty lists and a 2-D surface cannot be built from
    # fewer than three samples. NOTE(review): griddata may still reject
    # degenerate (e.g. collinear) sample sets -- confirm if that can occur.
    logger.warning(
        "Not enough data points to plot the ITL surface (need at least 3 points)"
    )
else:
    xi = np.linspace(min(x_kv_usage), max(x_kv_usage), 100)
    yi = np.linspace(min(y_context_length), max(y_context_length), 100)
    X, Y = np.meshgrid(xi, yi)
    Z = griddata((x_kv_usage, y_context_length), z_itl, (X, Y), method="cubic")

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")  # type: ignore

    # Create the surface plot with customizations
    surf = ax.plot_surface(  # type: ignore
        X,
        Y,
        Z,
        cmap=cm.coolwarm,  # type: ignore
        linewidth=0.2,
        antialiased=True,
        alpha=0.8,
    )

    # Add a color bar with custom settings
    cbar = fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
    cbar.set_label("Z Value", fontsize=12)
    cbar.ax.tick_params(labelsize=10)

    # Add labels with custom font sizes
    ax.set_xlabel("Active KV Percentage", fontsize=12)
    ax.set_ylabel("Decode Context Length", fontsize=12)
    ax.set_zlabel("ITL", fontsize=12)  # type: ignore

    # Set viewing angle
    ax.view_init(elev=30, azim=45)  # type: ignore
    ax.grid(True)
    ax.tick_params(axis="both", which="major", labelsize=10)

    surface_plot_path = f"{work_dir}/decode_tp{best_decode_tp}.png"
    logger.info(f"Saving ITL surface plot to {surface_plot_path}")
    plt.savefig(surface_plot_path, dpi=300, bbox_inches="tight")
    # Release the figure, as the prefill plots do after saving.
    plt.close(fig)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment