"composable_kernel/include/utility/config.hpp" did not exist on "3406a1148adf283f31a345549b63de633a4ff61e"
Commit 7fde7063 authored by helloyongyang's avatar helloyongyang
Browse files

support lru_cache

parent 7fc021e2
...@@ -347,7 +347,7 @@ if __name__ == "__main__": ...@@ -347,7 +347,7 @@ if __name__ == "__main__":
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
if ENABLE_GRAPH_MODE: if CHECK_ENABLE_GRAPH_MODE():
default_runner = DefaultRunner(model, inputs) default_runner = DefaultRunner(model, inputs)
runner = GraphRunner(default_runner) runner = GraphRunner(default_runner)
else: else:
......
...@@ -26,7 +26,7 @@ class HunyuanTransformerInfer: ...@@ -26,7 +26,7 @@ class HunyuanTransformerInfer:
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not ENABLE_GRAPH_MODE) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None): def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num) return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
......
...@@ -35,7 +35,7 @@ class WanTransformerInfer: ...@@ -35,7 +35,7 @@ class WanTransformerInfer:
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k, lq, lk return cu_seqlens_q, cu_seqlens_k, lq, lk
@torch.compile(disable=not ENABLE_GRAPH_MODE) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
......
import os import os
from functools import lru_cache
global ENABLE_PROFILING_DEBUG @lru_cache(maxsize=None)
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true" def CHECK_ENABLE_PROFILING_DEBUG():
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
return ENABLE_PROFILING_DEBUG
global ENABLE_GRAPH_MODE
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true" @lru_cache(maxsize=None)
def CHECK_ENABLE_GRAPH_MODE():
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
return ENABLE_GRAPH_MODE
...@@ -33,4 +33,4 @@ class _NullContext(ContextDecorator): ...@@ -33,4 +33,4 @@ class _NullContext(ContextDecorator):
ProfilingContext = _ProfilingContext ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if ENABLE_PROFILING_DEBUG else _NullContext ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
...@@ -24,6 +24,7 @@ fi ...@@ -24,6 +24,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \ python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \ --model_cls hunyuan \
......
...@@ -24,6 +24,7 @@ fi ...@@ -24,6 +24,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \ python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \ --model_cls hunyuan \
......
...@@ -30,6 +30,7 @@ fi ...@@ -30,6 +30,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \ python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls wan2.1 \ --model_cls wan2.1 \
......
...@@ -30,6 +30,7 @@ fi ...@@ -30,6 +30,7 @@ fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \ python ${lightx2v_path}/lightx2v/__main__.py \
--model_cls wan2.1 \ --model_cls wan2.1 \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment