Unverified Commit 9f662501 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Move torch.compile configs into cuda_graph_runner.py (#993)

parent ab787594
......@@ -71,6 +71,18 @@ def patch_model(
tp_group.ca_comm = backup_ca_comm
def set_torch_compile_config():
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 1024
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
self.model_runner = model_runner
......@@ -112,6 +124,9 @@ class CudaGraphRunner:
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile:
set_torch_compile_config()
def can_run(self, batch_size):
return batch_size < self.max_bs
......
......@@ -74,7 +74,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
kill_child_process,
maybe_set_triton_cache_manager,
set_torch_compile_config,
set_ulimit,
)
from sglang.utils import get_exception_traceback
......@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()
# Set torch compile config
if server_args.enable_torch_compile:
set_torch_compile_config()
# Set global chat template
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
......
......@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
dist.destroy_process_group()
def set_torch_compile_config():
# The following configurations are for torch compile optimizations
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 256
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
......
......@@ -6,7 +6,6 @@ import json
import logging
import signal
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment