"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "73acebb8cfbd1d2954cabe1af4185f9994e61917"
Unverified Commit 9f662501 authored by Ying Sheng, committed by GitHub

Move torch.compile configs into cuda_graph_runner.py (#993)

parent ab787594
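For context, the flags below are plain module-level settings in torch._dynamo.config and torch._inductor.config; they only take effect if they are set before torch.compile traces the model, which is why the helper now runs inside CudaGraphRunner when use_torch_compile is enabled. A minimal illustrative sketch (the toy Linear module is hypothetical and not part of this change):

import torch
import torch._dynamo.config
import torch._inductor.config

# Mirror the flags set by set_torch_compile_config() in cuda_graph_runner.py.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.accumulated_cache_size_limit = 1024

# Compile only after the flags are in place; this toy module is for illustration only.
model = torch.nn.Linear(16, 16)
compiled = torch.compile(model)
out = compiled(torch.randn(4, 16))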
...@@ -71,6 +71,18 @@ def patch_model(
tp_group.ca_comm = backup_ca_comm
def set_torch_compile_config():
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 1024
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
self.model_runner = model_runner
...@@ -112,6 +124,9 @@ class CudaGraphRunner:
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile:
set_torch_compile_config()
def can_run(self, batch_size):
return batch_size < self.max_bs
...
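As a rough usage sketch of the gating above (the model_runner object and the replay step are assumptions, not shown in this diff): the runner is built with the capture limit and the torch.compile switch, and callers check can_run before taking the CUDA-graph path.

runner = CudaGraphRunner(model_runner, max_batch_size_to_capture=32, use_torch_compile=True)
if runner.can_run(batch_size):
    # replay a captured CUDA graph for this batch size (replay API not shown in this diff)
    pass
else:
    # fall back to the regular eager forward path
    pass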
...@@ -74,7 +74,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
kill_child_process,
maybe_set_triton_cache_manager,
set_torch_compile_config,
set_ulimit,
)
from sglang.utils import get_exception_traceback
...@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()
# Set torch compile config
if server_args.enable_torch_compile:
set_torch_compile_config()
# Set global chat template
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
...
...@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
dist.destroy_process_group()
def set_torch_compile_config():
# The following configurations are for torch compile optimizations
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 256
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
...
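The set_ulimit helper is truncated above; a minimal sketch of the usual pattern with the standard resource module (not necessarily identical to the sglang implementation) is:

import resource

def set_ulimit(target_soft_limit=65535):
    # Raise the soft limit on open file descriptors up to the target.
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError:
            pass  # the OS may reject a soft limit above the hard limit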
...@@ -6,7 +6,6 @@ import json
import logging
import signal
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
...