Commit 1a798103 authored by gushiqiao, committed by GitHub

Merge pull request #98 from ModelTC/dev_gradio

Update gradio
parents 8bf5c2e1 0d69e3cf
@@ -11,7 +11,7 @@ from loguru import logger
 import importlib.util
 import psutil
+import random
 
 logger.add(
     "inference_logs.log",
@@ -22,6 +22,12 @@ logger.add(
     diagnose=True,
 )
 
+MAX_NUMPY_SEED = 2**32 - 1
+
+
+def generate_random_seed():
+    return random.randint(0, MAX_NUMPY_SEED)
+
 
 def is_module_installed(module_name):
     try:
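The new helper caps seeds at the 32-bit limit that numpy.random.seed() accepts. As a standalone sketch (standard library only), the added code amounts to:

import random

# numpy.random.seed() rejects seeds outside [0, 2**32 - 1],
# so the UI never offers a value beyond that range.
MAX_NUMPY_SEED = 2**32 - 1


def generate_random_seed() -> int:
    return random.randint(0, MAX_NUMPY_SEED)  # inclusive on both ends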
@@ -111,14 +117,13 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)
 
 
-def update_precision_mode(dit_quant_scheme):
-    if dit_quant_scheme != "bf16":
-        return "bf16"
-    return "fp32"
-
-
 global_runner = None
 current_config = None
+cur_dit_quant_scheme = None
+cur_clip_quant_scheme = None
+cur_t5_quant_scheme = None
+cur_precision_mode = None
+cur_enable_teacache = None
 
 available_quant_ops = get_available_quant_ops()
 quant_op_choices = []
@@ -175,6 +180,7 @@ def run_inference(
     attention_type = attention_type.split("(")[0].strip()
 
     global global_runner, current_config, model_path
+    global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
 
     if os.path.exists(os.path.join(model_path, "config.json")):
         with open(os.path.join(model_path, "config.json"), "r") as f:
@@ -279,7 +285,21 @@ def run_inference(
     else:
         clip_quant_ckpt = None
 
-    needs_reinit = lazy_load or global_runner is None or current_config is None or current_config.get("model_path") != model_path
+    needs_reinit = (
+        lazy_load
+        or global_runner is None
+        or current_config is None
+        or cur_dit_quant_scheme is None
+        or cur_dit_quant_scheme != dit_quant_scheme
+        or cur_clip_quant_scheme is None
+        or cur_clip_quant_scheme != clip_quant_scheme
+        or cur_t5_quant_scheme is None
+        or cur_t5_quant_scheme != t5_quant_scheme
+        or cur_precision_mode is None
+        or cur_precision_mode != precision_mode
+        or cur_enable_teacache is None
+        or cur_enable_teacache != enable_teacache
+    )
 
     if torch_compile:
         os.environ["ENABLE_GRAPH_MODE"] = "true"
@@ -294,7 +314,10 @@ def run_inference(
         if quant_op == "vllm":
             mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
         elif quant_op == "sgl":
-            mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
+            if dit_quant_scheme == "int8":
+                mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
+            else:
+                mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
         elif quant_op == "q8f":
             mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
@@ -389,6 +412,11 @@ def run_inference(
         runner = init_runner(config)
         current_config = config
+        cur_dit_quant_scheme = dit_quant_scheme
+        cur_clip_quant_scheme = clip_quant_scheme
+        cur_t5_quant_scheme = t5_quant_scheme
+        cur_precision_mode = precision_mode
+        cur_enable_teacache = enable_teacache
 
         if not lazy_load:
             global_runner = runner
@@ -597,7 +625,12 @@ def auto_configure(enable_auto_config, model_type, resolution):
         (32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
         (
             16,
-            {"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
+            {
+                "dit_quant_scheme_val": quant_type,
+                "t5_quant_scheme_val": quant_type,
+                "clip_quant_scheme_val": quant_type,
+                "lazy_load_val": True,
+            },
         ),
     ]
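The reflowed tier list pairs a free-memory threshold (GB) with the Gradio overrides to apply: at the 32 GB tier only the DiT is quantized, while at 16 GB the T5 and CLIP encoders are quantized as well. A hedged sketch of how such a threshold list is typically walked (the lookup helper and the exact threshold semantics are my reading, not code from this PR):

quant_type = "int8"  # stand-in; auto_configure picks the real scheme
tiers = [
    (32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
    (16, {"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type,
          "clip_quant_scheme_val": quant_type, "lazy_load_val": True}),
]


def pick_overrides(free_gb: float) -> dict:
    # Thresholds are ordered high to low; the first tier that fits wins.
    for threshold, overrides in tiers:
        if free_gb >= threshold:
            return overrides
    return tiers[-1][1]  # below every threshold: quantize everything


print(pick_overrides(24))  # falls into the 16 GB tier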
@@ -716,14 +749,20 @@ def main():
                 value="832x480",
                 label="Maximum Resolution",
             )
-        with gr.Column():
+        with gr.Column(scale=9):
             seed = gr.Slider(
                 label="Random Seed",
-                minimum=-10000000,
-                maximum=10000000,
+                minimum=0,
+                maximum=MAX_NUMPY_SEED,
                 step=1,
-                value=42,
+                value=generate_random_seed(),
             )
+        with gr.Column(scale=1):
+            randomize_btn = gr.Button("🎲 Randomize", variant="secondary")
+            randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
+        with gr.Column():
             infer_steps = gr.Slider(
                 label="Inference Steps",
                 minimum=1,
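The seed row is now split 9:1 between the slider and a randomize button, with the button's click event feeding a fresh seed straight into the slider. A self-contained Gradio sketch of the same wiring:

import random

import gradio as gr

MAX_NUMPY_SEED = 2**32 - 1


def generate_random_seed():
    return random.randint(0, MAX_NUMPY_SEED)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=9):
            seed = gr.Slider(
                label="Random Seed",
                minimum=0,
                maximum=MAX_NUMPY_SEED,
                step=1,
                value=generate_random_seed(),
            )
        with gr.Column(scale=1):
            randomize_btn = gr.Button("🎲 Randomize", variant="secondary")
    # The click handler's return value is written into the slider component.
    randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)

# demo.launch()  # uncomment to serve the demo locally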
@@ -963,8 +1002,6 @@ def main():
         ],
     )
 
-    dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
-
     infer_btn.click(
         fn=run_inference,
         inputs=[
......
@@ -11,6 +11,7 @@ from loguru import logger
 import importlib.util
 import psutil
+import random
 
 logger.add(
@@ -22,9 +23,14 @@ logger.add(
     diagnose=True,
 )
 
+MAX_NUMPY_SEED = 2**32 - 1
+
+
+def generate_random_seed():
+    return random.randint(0, MAX_NUMPY_SEED)
+
 
 def is_module_installed(module_name):
-    """检查模块是否已安装"""
     try:
         spec = importlib.util.find_spec(module_name)
         return spec is not None
@@ -112,14 +118,13 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)
 
 
-def update_precision_mode(dit_quant_scheme):
-    if dit_quant_scheme != "bf16":
-        return "bf16"
-    return "fp32"
-
-
 global_runner = None
 current_config = None
+cur_dit_quant_scheme = None
+cur_clip_quant_scheme = None
+cur_t5_quant_scheme = None
+cur_precision_mode = None
+cur_enable_teacache = None
 
 available_quant_ops = get_available_quant_ops()
 quant_op_choices = []
@@ -176,6 +181,7 @@ def run_inference(
     attention_type = attention_type.split("(")[0].strip()
 
     global global_runner, current_config, model_path
+    global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
 
     if os.path.exists(os.path.join(model_path, "config.json")):
         with open(os.path.join(model_path, "config.json"), "r") as f:
@@ -280,7 +286,21 @@ def run_inference(
     else:
         clip_quant_ckpt = None
 
-    needs_reinit = lazy_load or global_runner is None or current_config is None or current_config.get("model_path") != model_path
+    needs_reinit = (
+        lazy_load
+        or global_runner is None
+        or current_config is None
+        or cur_dit_quant_scheme is None
+        or cur_dit_quant_scheme != dit_quant_scheme
+        or cur_clip_quant_scheme is None
+        or cur_clip_quant_scheme != clip_quant_scheme
+        or cur_t5_quant_scheme is None
+        or cur_t5_quant_scheme != t5_quant_scheme
+        or cur_precision_mode is None
+        or cur_precision_mode != precision_mode
+        or cur_enable_teacache is None
+        or cur_enable_teacache != enable_teacache
+    )
 
     if torch_compile:
         os.environ["ENABLE_GRAPH_MODE"] = "true"
@@ -295,7 +315,10 @@ def run_inference(
         if quant_op == "vllm":
             mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
         elif quant_op == "sgl":
-            mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
+            if dit_quant_scheme == "int8":
+                mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
+            else:
+                mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
         elif quant_op == "q8f":
             mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
@@ -389,6 +412,11 @@ def run_inference(
         runner = init_runner(config)
         current_config = config
+        cur_dit_quant_scheme = dit_quant_scheme
+        cur_clip_quant_scheme = clip_quant_scheme
+        cur_t5_quant_scheme = t5_quant_scheme
+        cur_precision_mode = precision_mode
+        cur_enable_teacache = enable_teacache
 
         if not lazy_load:
             global_runner = runner
@@ -597,7 +625,12 @@ def auto_configure(enable_auto_config, model_type, resolution):
         (32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
         (
             16,
-            {"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
+            {
+                "dit_quant_scheme_val": quant_type,
+                "t5_quant_scheme_val": quant_type,
+                "clip_quant_scheme_val": quant_type,
+                "lazy_load_val": True,
+            },
         ),
     ]
@@ -716,14 +749,19 @@ def main():
                 value="832x480",
                 label="最大分辨率",
             )
-        with gr.Column():
+        with gr.Column(scale=9):
             seed = gr.Slider(
                 label="随机种子",
-                minimum=-10000000,
-                maximum=10000000,
+                minimum=0,
+                maximum=MAX_NUMPY_SEED,
                 step=1,
-                value=42,
+                value=generate_random_seed(),
             )
+        with gr.Column(scale=1):
+            randomize_btn = gr.Button("🎲 生成随机种子", variant="secondary")
+            randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
+        with gr.Column():
             infer_steps = gr.Slider(
                 label="推理步数",
                 minimum=1,
@@ -963,8 +1001,6 @@ def main():
         ],
     )
 
-    dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
-
    infer_btn.click(
         fn=run_inference,
         inputs=[
......
 #!/bin/bash
 lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_new/lightx2v
-model_path=/data/nvme0/gushiqiao/models/Wan2.1-I2V-14B-480P-Lightx2v
+model_path=/data/nvme0/gushiqiao/models/I2V/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill
 
 export CUDA_VISIBLE_DEVICES=7
 export CUDA_LAUNCH_BLOCKING=1
......
@@ -572,7 +572,7 @@ def convert_weights(args):
         json.dump(index, f, indent=2)
     logger.info(f"Index file written to: {index_path}")
 
-    if os.path.isdir(args.source):
+    if os.path.isdir(args.source) and args.copy_no_weight_files:
         copy_non_weight_files(args.source, args.output)
@@ -650,6 +650,7 @@ def main():
         default=[1.0],
         help="Alpha for LoRA weight scaling",
     )
+    parser.add_argument("--copy_no_weight_files", action="store_true")
     args = parser.parse_args()
 
     if args.quantized:
......
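Copying non-weight files (configs, tokenizer assets, and the like) is now opt-in: --copy_no_weight_files is a store_true flag, so it defaults to False and the copy only runs when the flag is passed for an existing source directory. A minimal sketch of the gating pattern (the copy helper body is illustrative, not the converter's actual implementation):

import argparse
import os
import shutil


def copy_non_weight_files(source: str, output: str) -> None:
    # Illustrative stand-in: copy every file that is not a weight shard.
    os.makedirs(output, exist_ok=True)
    for name in os.listdir(source):
        path = os.path.join(source, name)
        if os.path.isfile(path) and not name.endswith((".safetensors", ".bin", ".pt")):
            shutil.copy2(path, output)


parser = argparse.ArgumentParser()
parser.add_argument("--source", required=True)
parser.add_argument("--output", required=True)
# store_true: absent -> False, present -> True; the flag takes no value.
parser.add_argument("--copy_no_weight_files", action="store_true")
args = parser.parse_args()

if os.path.isdir(args.source) and args.copy_no_weight_files:
    copy_non_weight_files(args.source, args.output)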