Commit 4e668794 authored by Muyang Li

add use count in the demo

parent 0ffe1607
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import tempfile
import time
@@ -73,6 +74,16 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        print(f"Use count: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
    return result_image, latency_str
...
@@ -6,13 +6,11 @@ import argparse
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # parser.add_argument(
    #     "-m", "--model", type=str, default="pretrained/converted/sketch.safetensors", help="Path to the model"
    # )
    parser.add_argument(
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    args = parser.parse_args()
    return args
# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
import argparse
import os
import random
import time
@@ -30,6 +31,7 @@ def get_args() -> argparse.Namespace:
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    return parser.parse_args()
@@ -124,6 +126,18 @@ def generate(
    for i in range(len(latency_strs)):
        latency_strs[i] += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        print(f"Use count: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
    return *images, *latency_strs
...
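Both hunks add the same file-backed counter: when the demo is launched with --count-use, each generation reads use_count.txt (starting from 0 if the file does not exist yet), increments the count, prints it, and writes it back. A minimal stand-alone sketch of that logic; the helper name bump_use_count and the default path argument are illustrative, not part of this commit:

import os


def bump_use_count(path: str = "use_count.txt") -> int:
    # Read the previous count if the file exists, otherwise start from zero.
    if os.path.exists(path):
        with open(path, "r") as f:
            count = int(f.read())
    else:
        count = 0
    count += 1
    # Persist the updated count so the next invocation continues from here.
    with open(path, "w") as f:
        f.write(str(count))
    return count


if __name__ == "__main__":
    print(f"Use count: {bump_use_count()}")

The counter is plain file I/O with no locking, so two concurrent generations can read the same value and both write the same incremented count; that is acceptable for a rough usage tally but not for exact accounting.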