import json
import requests
import os
import uuid
import shutil
import time
import logging
import argparse
from datetime import datetime
from PIL import Image
from config import female_images, male_images
from typing import List, Tuple, Optional, Dict, Any
from contextlib import contextmanager

CURRENT_DIR = os.path.dirname(__file__)
INPUT_DIR = os.path.join(CURRENT_DIR, "..", "input")
OUTPUT_DIR = os.path.join(CURRENT_DIR, "..", "output")
CONCAT_OUTPUT_DIR = os.path.join(CURRENT_DIR, "concat_output")
LOG_DIR = os.path.join(CURRENT_DIR, "logs")

# Constants
DEFAULT_TIMEOUT = 2400
EXTENDED_TIMEOUT = 1800
POLL_INTERVAL = 0.1
MAX_RETRIES = 3

# Create necessary directories
for directory in [LOG_DIR, CONCAT_OUTPUT_DIR]:
    os.makedirs(directory, exist_ok=True)

# Configure logging
log_filename = os.path.join(LOG_DIR, f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_filename, encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class TimingRecorder:
    """Record and calculate timing statistics"""

    def __init__(self):
        self.times = []

    def add(self, elapsed_time: float):
        self.times.append(elapsed_time)

    def get_statistics(self) -> Optional[Dict[str, Any]]:
        if not self.times:
            return None
        return {
            'times': self.times,
            'average': sum(self.times) / len(self.times),
            'count': len(self.times),
            'min': min(self.times),
            'max': max(self.times)
        }

    def log_statistics(self, workflow_name: str, extra_info: str = ""):
        stats = self.get_statistics()
        if stats:
            logger.info(f"\n========> {workflow_name} {extra_info} Timing Statistics <========")
            logger.info(f"Time list (excluding warmup): {[f'{t:.1f}s' for t in stats['times']]}")
            logger.info(f"Average time: {stats['average']:.1f}s")
            logger.info(f"Min time: {stats['min']:.1f}s, Max time: {stats['max']:.1f}s")
            logger.info(f"Total processed: {stats['count']}")
        else:
            logger.warning("No timing data collected")


def queue_prompt(prompt: dict, server_url: str, retries: int = MAX_RETRIES) -> bool:
    """Queue a prompt to the server with retry logic"""
    p = {"prompt": prompt}
    data = json.dumps(p).encode('utf-8')
    for attempt in range(retries):
        try:
            response = requests.post(f"{server_url}/prompt", data=data, timeout=30)
            response.raise_for_status()
            logger.info(f"Queue prompt response: {response.json()}")
            return True
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to queue prompt (attempt {attempt + 1}/{retries}): {str(e)}")
            if attempt < retries - 1:
                time.sleep(1)
            else:
                raise Exception(f"Failed to queue prompt after {retries} attempts") from e
    return False


def resize_image_to_height(image: Image.Image, target_height: int) -> Image.Image:
    """Resize image to target height while maintaining aspect ratio"""
    width, height = image.size
    aspect_ratio = width / height
    new_width = int(target_height * aspect_ratio)
    return image.resize((new_width, target_height), Image.Resampling.LANCZOS)


def concat_images(image_paths: List[str], output_path: str) -> bool:
    """Concatenate images horizontally"""
    try:
        if not image_paths:
            logger.error("Image path list is empty")
            return False

        # Load all images with context manager
        images = []
        for path in image_paths:
            try:
                with Image.open(path) as img:
                    images.append(img.copy())  # Copy pixel data so the file handle can be closed
            except Exception as e:
                logger.error(f"Failed to load image {path}: {str(e)}")
                return False

        if not images:
            logger.error("No valid images loaded")
            return False

        target_height = max(img.height for img in images)
        resized_images = [resize_image_to_height(img, target_height) for img in images]
        total_width = sum(img.width for img in resized_images)

        concatenated_image = Image.new('RGB', (total_width, target_height), (255, 255, 255))
        x_offset = 0
        for img in resized_images:
            concatenated_image.paste(img, (x_offset, 0))
            x_offset += img.width

        concatenated_image.save(output_path)
        logger.info(f"Successfully saved concatenated image to {output_path}")

        # Close all images
        for img in images:
            img.close()

        return True
    except Exception as e:
        logger.error(f"Failed to concatenate images: {str(e)}")
        return False


def check_prefix_image(dir_path: str, prefix: str, timeout: int = DEFAULT_TIMEOUT) -> Optional[str]:
    """Check for image with specific prefix in directory"""
    start_time = time.time()
    checked_files = set()
    while time.time() - start_time < timeout:
        try:
            files = os.listdir(dir_path)
            for file in files:
                file_path = os.path.join(dir_path, file)
                if file.startswith(prefix):
                    try:
                        # Verify image can be opened
                        with Image.open(file_path) as img:
                            img.verify()
                        # Re-open to ensure file is complete
                        with Image.open(file_path) as img:
                            img.load()
                        logger.info(f"Found valid output image: {file}")
                        return file
                    except Exception as e:
                        logger.debug(f"File {file} not ready yet: {str(e)}")
                        continue
            time.sleep(POLL_INTERVAL)
        except Exception as e:
            logger.error(f"Error checking directory: {str(e)}")
            time.sleep(POLL_INTERVAL)
    logger.error(f"Timeout waiting for image with prefix {prefix}")
    return None


def clean_output_directories(*directories: str):
    """Clean PNG files from output directories"""
    for directory in directories:
        try:
            if not os.path.exists(directory):
                logger.warning(f"Directory does not exist: {directory}")
                continue
            png_files = [f for f in os.listdir(directory) if f.endswith(".png")]
            for file in png_files:
                try:
                    os.remove(os.path.join(directory, file))
                    logger.debug(f"Removed: {file}")
                except Exception as e:
                    logger.error(f"Failed to remove {file}: {str(e)}")
            logger.info(f"Cleaned {len(png_files)} PNG files from {directory}")
        except Exception as e:
            logger.error(f"Failed to clean directory {directory}: {str(e)}")


@contextmanager
def workflow_execution_timer(workflow_name: str, skip_first: bool = True):
    """Context manager for timing workflow execution"""
    recorder = TimingRecorder()
    start_time = time.perf_counter()
    try:
        yield recorder
    finally:
        pass


def test1(server_url: str):
    """Test human portrait workflow with different weight types"""
    # workflow_name = "1-人像写真-开放平台-真实照片-v1.json"
    workflow_name = "1-人像写真-开放平台-真实照片-v1_torchcompile.json"
    workflow_file = os.path.join(CURRENT_DIR, "workflows", workflow_name)
    test_images = female_images

    with open(workflow_file) as f:
        api_prompt = json.load(f)

    image_mapping = {}
    for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast'][::2]:
        logger.info(f'========> {workflow_name} {weight_dtype} <========')
        recorder = TimingRecorder()
        for idx, image in enumerate(test_images):
            api_prompt["16"]["inputs"]["image"] = image
            task_id = str(uuid.uuid4()).replace('-', '')
            api_prompt["219"]["inputs"]["filename_prefix"] = task_id
            api_prompt["150"]["inputs"]["weight_dtype"] = weight_dtype
            # api_prompt["203"]["inputs"]["bbox_detector"] = "yolox_l.torchscript.pt"
            api_prompt["203"]["inputs"]["pose_estimator"] = "dw-ll_ucoco_384.onnx"

            try:
                queue_prompt(api_prompt, server_url)
            except Exception as e:
                response = requests.get(f"{server_url}/history")
                print(response.text[:2000])  # Inspect the most recent errors
                raise
            # requests.post(f"{server_url}/free", json={"unload_models": True})
            # time.sleep(2)
            logger.info(f'Waiting for output... {task_id}')
            tic = time.perf_counter()
            output_image = check_prefix_image(OUTPUT_DIR, task_id, DEFAULT_TIMEOUT)
            toc = time.perf_counter()
            elapsed_time = toc - tic
            logger.info(f"Time taken: {elapsed_time:.1f} seconds")

            # Skip first image for timing statistics
            if idx > 0:
                recorder.add(elapsed_time)

            if output_image is None:
                raise Exception(f"Output image {task_id} not found")

            if image not in image_mapping:
                image_mapping[image] = []
            image_mapping[image].append(output_image)

        # Output timing statistics
        recorder.log_statistics(os.path.basename(workflow_file), weight_dtype)

    # Create concatenated result images
    for k, v in image_mapping.items():
        image_paths = [os.path.join(INPUT_DIR, k)] + [os.path.join(OUTPUT_DIR, img) for img in v]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"1_{k.split('.')[0]}.png")
        concat_images(image_paths, output_path)


def test2(server_url: str):
    """Test clothing change workflow"""
    # workflow_file = os.path.join(CURRENT_DIR, "workflows", "2-换装-aigc开放平台.json")
    workflow_file = os.path.join(CURRENT_DIR, "workflows", "2-换装-aigc开放平台_torchcompile.json")
    test_images = []
    for image1 in male_images:
        for image2 in male_images:
            if image1 != image2:
                test_images.append((image1, image2))

    with open(workflow_file) as f:
        api_prompt = json.load(f)

    logger.info(f'\n========> {os.path.basename(workflow_file)} <========')
    recorder = TimingRecorder()
    for idx, (image1, image2) in enumerate(test_images[:2]):
        api_prompt["158"]["inputs"]["image"] = image1
        api_prompt["161"]["inputs"]["image"] = image2
        task_id = str(uuid.uuid4()).replace('-', '')
        api_prompt["391"]["inputs"]["filename_prefix"] = task_id
        api_prompt["358"]["inputs"]["bbox_detector"] = "yolox_l.onnx"
        api_prompt["358"]["inputs"]["pose_estimator"] = "dw-ll_ucoco_384.onnx"

        queue_prompt(api_prompt, server_url)
        logger.info(f'Waiting for output... {task_id}')

        tic = time.perf_counter()
        output_image = check_prefix_image(OUTPUT_DIR, task_id, EXTENDED_TIMEOUT)
        toc = time.perf_counter()
        elapsed_time = toc - tic
        logger.info(f"Time taken: {elapsed_time:.1f} seconds")

        # Skip first image pair for timing statistics
        if idx > 0:
            recorder.add(elapsed_time)

        if output_image is None:
            raise Exception(f"Output image {task_id} not found")

        image_paths = [
            os.path.join(INPUT_DIR, image1),
            os.path.join(INPUT_DIR, image2),
            os.path.join(OUTPUT_DIR, output_image)
        ]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"2_{task_id}.png")
        concat_images(image_paths, output_path)

    # Output timing statistics
    recorder.log_statistics(os.path.basename(workflow_file))


def test_flux1dev_t2i(workflow_file: str, server_url: str, base_seed: int = 411647920510829, test_id: int = 0):
    """Test FLUX.1-dev text-to-image workflow with different resolutions"""
    with open(workflow_file) as f:
        api_prompt = json.load(f)

    resolutions = [(512, 768), (576, 1024), (768, 1024), (720, 1280), (1080, 1920)]
    for resolution in resolutions:
        width, height = resolution
        logger.info(f'========> {os.path.basename(workflow_file)} {width}x{height} seed:{base_seed} <========')
        recorder = TimingRecorder()
        output_images = []
        for i in range(2):
            api_prompt["25"]["inputs"]["noise_seed"] = base_seed + i
            api_prompt["27"]["inputs"]["width"] = width
            api_prompt["27"]["inputs"]["height"] = height
            api_prompt["30"]["inputs"]["width"] = width
            api_prompt["30"]["inputs"]["height"] = height
            task_id = str(uuid.uuid4()).replace('-', '')
            api_prompt["9"]["inputs"]["filename_prefix"] = task_id

            queue_prompt(api_prompt, server_url)
            logger.info(f'Waiting for output... {task_id}')
            tic = time.perf_counter()
            output_image = check_prefix_image(OUTPUT_DIR, task_id, DEFAULT_TIMEOUT)
            toc = time.perf_counter()
            elapsed_time = toc - tic
            logger.info(f"Time taken: {elapsed_time:.1f} seconds")

            # Skip first image for timing statistics
            if i > 0:
                recorder.add(elapsed_time)

            if output_image is None:
                raise Exception(f"Output image {task_id} not found")

            output_images.append(output_image)

        # Create concatenated result
        image_paths = [os.path.join(OUTPUT_DIR, img) for img in output_images]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"{test_id}_{width}x{height}_{base_seed}.png")
        concat_images(image_paths, output_path)

        # Output timing statistics
        recorder.log_statistics(os.path.basename(workflow_file), f"{width}x{height}")


def test5(server_url: str):
    """Test golden human portrait workflow"""
    workflow_name = "5-huangjin_human.json"
    workflow_file = os.path.join(CURRENT_DIR, "workflows", workflow_name)
    test_images = female_images

    with open(workflow_file) as f:
        api_prompt = json.load(f)

    logger.info(f'\n========> {workflow_name} <========')
    recorder = TimingRecorder()
    image_mapping = {}
    for idx, image in enumerate(test_images):
        api_prompt["11"]["inputs"]["image"] = image
        task_id = str(uuid.uuid4()).replace('-', '')
        api_prompt["61"]["inputs"]["filename_prefix"] = task_id

        queue_prompt(api_prompt, server_url)
        logger.info(f'Waiting for output... {task_id}')

        tic = time.perf_counter()
        output_image = check_prefix_image(OUTPUT_DIR, task_id, EXTENDED_TIMEOUT)
        toc = time.perf_counter()
        elapsed_time = toc - tic
        logger.info(f"Time taken: {elapsed_time:.1f} seconds")

        # Skip first image for timing statistics
        if idx > 0:
            recorder.add(elapsed_time)

        if output_image is None:
            raise Exception(f"Output image {task_id} not found")

        image_mapping[image] = output_image

    # Create concatenated result images
    for input_img, output_img in image_mapping.items():
        image_paths = [
            os.path.join(INPUT_DIR, input_img),
            os.path.join(OUTPUT_DIR, output_img)
        ]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"5_{input_img.split('.')[0]}.png")
        concat_images(image_paths, output_path)

    # Output timing statistics
    recorder.log_statistics(workflow_name)


def test6(server_url: str):
    """Test Qwen image edit workflow with different weight types"""
    workflow_name = "6-qwen-image-edit.json"
    workflow_file = os.path.join(CURRENT_DIR, "workflows", workflow_name)

    with open(workflow_file) as f:
        api_prompt = json.load(f)

    # Test image pairs for clothing change
    test_cases = []
    for image1 in male_images:
        for image2 in male_images:
            if image1 != image2:
                test_cases.append((image1, image2))

    # Add SaveImage node if not exists (node 246 is PreviewImage)
    if "256" not in api_prompt:
        api_prompt["256"] = {
            "inputs": {
                "images": ["241", 0],
                "filename_prefix": "qwen_output"
            },
            "class_type": "SaveImage"
        }

    image_mapping = {}
    for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast']:
        logger.info(f'\n========> {workflow_name} {weight_dtype} <========')
        recorder = TimingRecorder()
        for idx, (image1, image2) in enumerate(test_cases):
            api_prompt["247"]["inputs"]["image"] = image1
            api_prompt["248"]["inputs"]["image"] = image2
            api_prompt["236"]["inputs"]["weight_dtype"] = weight_dtype
            task_id = str(uuid.uuid4()).replace('-', '')
            api_prompt["256"]["inputs"]["filename_prefix"] = task_id

            queue_prompt(api_prompt, server_url)
            logger.info(f'Waiting for output... {task_id}')
            tic = time.perf_counter()
            output_image = check_prefix_image(OUTPUT_DIR, task_id, EXTENDED_TIMEOUT)
            toc = time.perf_counter()
            elapsed_time = toc - tic
            logger.info(f"Time taken: {elapsed_time:.1f} seconds")

            # Skip first image for timing statistics
            if idx > 0:
                recorder.add(elapsed_time)

            if output_image is None:
                raise Exception(f"Output image {task_id} not found")

            case_key = f"{image1}_{image2}"
            if case_key not in image_mapping:
                image_mapping[case_key] = []
            image_mapping[case_key].append(output_image)

        # Output timing statistics
        recorder.log_statistics(workflow_name, weight_dtype)

    # Create concatenated result images
    for case_key, output_images in image_mapping.items():
        image1, image2 = case_key.split('_', 1)
        image_paths = [
            os.path.join(INPUT_DIR, image1),
            os.path.join(INPUT_DIR, image2)
        ] + [os.path.join(OUTPUT_DIR, img) for img in output_images]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"6_{case_key.replace('.', '_')}.png")
        concat_images(image_paths, output_path)


def test7(server_url: str):
    """Test Kontext+ACE old photo restoration workflow with different weight types"""
    workflow_name = "7-kontext+ace 老照片修复_torchcompile.json"
    workflow_file = os.path.join(CURRENT_DIR, "workflows", workflow_name)
    test_images = female_images

    with open(workflow_file) as f:
        api_prompt = json.load(f)

    image_mapping = {}
    for weight_dtype in ['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast']:
        # for weight_dtype in ['fp8_e4m3fn_fast']:
        if weight_dtype == "fp8_e4m3fn_fast":
            api_prompt["12"]["inputs"]["unet_name"] = "flux1-kontext-dev.safetensors"
        logger.info(f'\n========> {workflow_name} {weight_dtype} <========')
        recorder = TimingRecorder()
        for idx, image in enumerate(test_images):
            api_prompt["41"]["inputs"]["image"] = image
            api_prompt["12"]["inputs"]["weight_dtype"] = weight_dtype
            task_id = str(uuid.uuid4()).replace('-', '')
            api_prompt["9"]["inputs"]["filename_prefix"] = task_id

            queue_prompt(api_prompt, server_url)
            logger.info(f'Waiting for output... {task_id}')
            tic = time.perf_counter()
            output_image = check_prefix_image(OUTPUT_DIR, task_id, EXTENDED_TIMEOUT)
            toc = time.perf_counter()
            elapsed_time = toc - tic
            logger.info(f"Time taken: {elapsed_time:.1f} seconds")

            # Skip first image for timing statistics
            if idx > 0:
                recorder.add(elapsed_time)

            if output_image is None:
                raise Exception(f"Output image {task_id} not found")

            if image not in image_mapping:
                image_mapping[image] = []
            image_mapping[image].append(output_image)

        # Output timing statistics
        recorder.log_statistics(workflow_name, weight_dtype)

    # Create concatenated result images
    for input_img, output_images in image_mapping.items():
        image_paths = [os.path.join(INPUT_DIR, input_img)] + [os.path.join(OUTPUT_DIR, img) for img in output_images]
        output_path = os.path.join(CONCAT_OUTPUT_DIR, f"7_{input_img.split('.')[0]}.png")
        concat_images(image_paths, output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='ComfyUI workflow test script')
    parser.add_argument('--ip', type=str, default='127.0.0.1',
                        help='Server IP address (default: 127.0.0.1)')
    parser.add_argument('--port', type=int, default=8188,
                        help='Server port (default: 8188)')
    parser.add_argument('--skip-cleanup', action='store_true',
                        help='Skip cleaning output directories before tests')
    args = parser.parse_args()

    server_url = f"http://{args.ip}:{args.port}"

    logger.info(f"Log file saved at: {log_filename}")
    logger.info(f"Server URL: {server_url}")
    logger.info("Starting workflow tests...")

    # Clean output directories
    if not args.skip_cleanup:
        clean_output_directories(OUTPUT_DIR, CONCAT_OUTPUT_DIR)

    try:
        # Test human portrait workflow with flux.1-dev model
        # test1(server_url)
        # test1(server_url)

        # Test clothing change workflow
        test2(server_url)

        # Test flux.1-dev t2i workflows
        # os.path.join(CURRENT_DIR, "workflows", "3-flux.1-dev-t2i.json"),
        """
        test_flux1dev_t2i(
            os.path.join(CURRENT_DIR, "workflows", "3-flux.1-dev-t2i_torchcompile.json"),
            server_url,
            test_id=3
        )
        """

        # Test int4 in L20/L40
        # test_flux1dev_t2i(
        #     os.path.join(CURRENT_DIR, "workflows", "4-flux.1-dev-t2i-nunchaku.json"),
        #     server_url,
        #     test_id=4
        # )

        # Test golden human portrait workflow
        # test5(server_url)

        # Test Qwen image edit workflow
        # test6(server_url)

        # Test old photo restoration workflow
        # test7(server_url)
        # test7(server_url)
        # test7(server_url)

        logger.info("All tests completed successfully!")
    except Exception as e:
        logger.error(f"Test failed with error: {str(e)}", exc_info=True)
        raise
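
# Example invocation (a sketch; the script filename below is a placeholder -- substitute the
# actual filename of this file, and make sure a ComfyUI server is already listening on the
# given address before running):
#   python run_workflow_tests.py --ip 127.0.0.1 --port 8188
#   python run_workflow_tests.py --port 8189 --skip-cleanup   # keep existing output PNGs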