Commit 7fde7063 authored by helloyongyang's avatar helloyongyang
Browse files

support lru_cache

parent 7fc021e2
......@@ -347,7 +347,7 @@ if __name__ == "__main__":
gc.collect()
torch.cuda.empty_cache()
if ENABLE_GRAPH_MODE:
if CHECK_ENABLE_GRAPH_MODE():
default_runner = DefaultRunner(model, inputs)
runner = GraphRunner(default_runner)
else:
......
......@@ -26,7 +26,7 @@ class HunyuanTransformerInfer:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not ENABLE_GRAPH_MODE)
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
......
......@@ -35,7 +35,7 @@ class WanTransformerInfer:
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k, lq, lk
@torch.compile(disable=not ENABLE_GRAPH_MODE)
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
......
import os
from functools import lru_cache
global ENABLE_PROFILING_DEBUG
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
@lru_cache(maxsize=None)
def CHECK_ENABLE_PROFILING_DEBUG():
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
return ENABLE_PROFILING_DEBUG
global ENABLE_GRAPH_MODE
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
@lru_cache(maxsize=None)
def CHECK_ENABLE_GRAPH_MODE():
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
return ENABLE_GRAPH_MODE
......@@ -33,4 +33,4 @@ class _NullContext(ContextDecorator):
ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if ENABLE_PROFILING_DEBUG else _NullContext
ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
......@@ -24,6 +24,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
......
......@@ -24,6 +24,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
......
......@@ -30,6 +30,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls wan2.1 \
......
......@@ -30,6 +30,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls wan2.1 \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment