Commit 1a798103 authored by gushiqiao, committed by GitHub

Merge pull request #98 from ModelTC/dev_gradio

Update gradio
parents 8bf5c2e1 0d69e3cf
@@ -11,7 +11,7 @@ from loguru import logger
import importlib.util
import psutil
import random
logger.add(
"inference_logs.log",
@@ -22,6 +22,12 @@ logger.add(
diagnose=True,
)
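# numpy RNG seeds must fall in [0, 2**32 - 1], hence this upper bound.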
MAX_NUMPY_SEED = 2**32 - 1
def generate_random_seed():
return random.randint(0, MAX_NUMPY_SEED)
def is_module_installed(module_name):
try:
@@ -111,14 +117,13 @@ def is_fp8_supported_gpu():
return (major == 8 and minor == 9) or (major >= 9)
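# Map the selected DiT quantization scheme to a matching precision mode for the UI.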
def update_precision_mode(dit_quant_scheme):
if dit_quant_scheme != "bf16":
return "bf16"
return "fp32"
global_runner = None
current_config = None
cur_dit_quant_scheme = None
cur_clip_quant_scheme = None
cur_t5_quant_scheme = None
cur_precision_mode = None
cur_enable_teacache = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
@@ -175,6 +180,7 @@ def run_inference(
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
@@ -279,7 +285,21 @@ def run_inference(
else:
clip_quant_ckpt = None
needs_reinit = lazy_load or global_runner is None or current_config is None or current_config.get("model_path") != model_path
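# Re-initialize the runner when lazy loading is enabled or any quantization,
# precision, or TeaCache setting changed since the previous run.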
needs_reinit = (
lazy_load
or global_runner is None
or current_config is None
or cur_dit_quant_scheme is None
or cur_dit_quant_scheme != dit_quant_scheme
or cur_clip_quant_scheme is None
or cur_clip_quant_scheme != clip_quant_scheme
or cur_t5_quant_scheme is None
or cur_t5_quant_scheme != t5_quant_scheme
or cur_precision_mode is None
or cur_precision_mode != precision_mode
or cur_enable_teacache is None
or cur_enable_teacache != enable_teacache
)
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
@@ -294,7 +314,10 @@ def run_inference(
if quant_op == "vllm":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
elif quant_op == "sgl":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
if dit_quant_scheme == "int8":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
else:
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
@@ -389,6 +412,11 @@ def run_inference(
runner = init_runner(config)
current_config = config
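# Record the settings used for this init so the next call can detect changes.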
cur_dit_quant_scheme = dit_quant_scheme
cur_clip_quant_scheme = clip_quant_scheme
cur_t5_quant_scheme = t5_quant_scheme
cur_precision_mode = precision_mode
cur_enable_teacache = enable_teacache
if not lazy_load:
global_runner = runner
@@ -597,7 +625,12 @@ def auto_configure(enable_auto_config, model_type, resolution):
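# Tiered presets keyed by memory budget (the thresholds are presumably GiB):
# tighter budgets quantize more components and enable lazy loading.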
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
},
),
]
@@ -716,14 +749,20 @@ def main():
value="832x480",
label="Maximum Resolution",
)
with gr.Column():
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
minimum=-10000000,
maximum=10000000,
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=42,
value=generate_random_seed(),
)
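# The dice button re-rolls the seed slider with a fresh random value.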
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 Randomize", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
@@ -963,8 +1002,6 @@ def main():
],
)
dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
infer_btn.click(
fn=run_inference,
inputs=[
......
@@ -11,6 +11,7 @@ from loguru import logger
import importlib.util
import psutil
import random
logger.add(
@@ -22,9 +23,14 @@ logger.add(
diagnose=True,
)
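# numpy RNG seeds must fall in [0, 2**32 - 1], hence this upper bound.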
MAX_NUMPY_SEED = 2**32 - 1
def generate_random_seed():
return random.randint(0, MAX_NUMPY_SEED)
def is_module_installed(module_name):
"""检查模块是否已安装"""
try:
spec = importlib.util.find_spec(module_name)
return spec is not None
@@ -112,14 +118,13 @@ def is_fp8_supported_gpu():
return (major == 8 and minor == 9) or (major >= 9)
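# Map the selected DiT quantization scheme to a matching precision mode for the UI.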
def update_precision_mode(dit_quant_scheme):
if dit_quant_scheme != "bf16":
return "bf16"
return "fp32"
global_runner = None
current_config = None
cur_dit_quant_scheme = None
cur_clip_quant_scheme = None
cur_t5_quant_scheme = None
cur_precision_mode = None
cur_enable_teacache = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
@@ -176,6 +181,7 @@ def run_inference(
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
@@ -280,7 +286,21 @@ def run_inference(
else:
clip_quant_ckpt = None
needs_reinit = lazy_load or global_runner is None or current_config is None or current_config.get("model_path") != model_path
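# Re-initialize the runner when lazy loading is enabled or any quantization,
# precision, or TeaCache setting changed since the previous run.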
needs_reinit = (
lazy_load
or global_runner is None
or current_config is None
or cur_dit_quant_scheme is None
or cur_dit_quant_scheme != dit_quant_scheme
or cur_clip_quant_scheme is None
or cur_clip_quant_scheme != clip_quant_scheme
or cur_t5_quant_scheme is None
or cur_t5_quant_scheme != t5_quant_scheme
or cur_precision_mode is None
or cur_precision_mode != precision_mode
or cur_enable_teacache is None
or cur_enable_teacache != enable_teacache
)
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
@@ -295,7 +315,10 @@ def run_inference(
if quant_op == "vllm":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
elif quant_op == "sgl":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
if dit_quant_scheme == "int8":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
else:
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
@@ -389,6 +412,11 @@ def run_inference(
runner = init_runner(config)
current_config = config
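# Record the settings used for this init so the next call can detect changes.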
cur_dit_quant_scheme = dit_quant_scheme
cur_clip_quant_scheme = clip_quant_scheme
cur_t5_quant_scheme = t5_quant_scheme
cur_precision_mode = precision_mode
cur_enable_teacache = enable_teacache
if not lazy_load:
global_runner = runner
@@ -597,7 +625,12 @@ def auto_configure(enable_auto_config, model_type, resolution):
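# Tiered presets keyed by memory budget (the thresholds are presumably GiB):
# tighter budgets quantize more components and enable lazy loading.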
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{"dit_quant_scheme_val": quant_type, "t5_quant_scheme_val": quant_type, "clip_quant_scheme_val": quant_type, "lazy_load_val": True},
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
},
),
]
@@ -716,14 +749,19 @@ def main():
value="832x480",
label="最大分辨率",
)
with gr.Column():
with gr.Column(scale=9):
seed = gr.Slider(
label="随机种子",
minimum=-10000000,
maximum=10000000,
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=42,
value=generate_random_seed(),
)
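# The dice button re-rolls the seed slider with a fresh random value.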
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 生成随机种子", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
infer_steps = gr.Slider(
label="推理步数",
minimum=1,
@@ -963,8 +1001,6 @@ def main():
],
)
dit_quant_scheme.change(fn=update_precision_mode, inputs=[dit_quant_scheme], outputs=[precision_mode])
infer_btn.click(
fn=run_inference,
inputs=[
......
#!/bin/bash
lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_new/lightx2v
model_path=/data/nvme0/gushiqiao/models/Wan2.1-I2V-14B-480P-Lightx2v
model_path=/data/nvme0/gushiqiao/models/I2V/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill
export CUDA_VISIBLE_DEVICES=7
export CUDA_LAUNCH_BLOCKING=1
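# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so CUDA errors
# surface at the offending call (useful for debugging, slower inference).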
......
@@ -572,7 +572,7 @@ def convert_weights(args):
json.dump(index, f, indent=2)
logger.info(f"Index file written to: {index_path}")
if os.path.isdir(args.source):
if os.path.isdir(args.source) and args.copy_no_weight_files:
copy_non_weight_files(args.source, args.output)
@@ -650,6 +650,7 @@ def main():
default=[1.0],
help="Alpha for LoRA weight scaling",
)
parser.add_argument("--copy_no_weight_files", action="store_true")
args = parser.parse_args()
if args.quantized:
......