Commit 57463d8d authored by suily's avatar suily
Browse files

init

parents
Pipeline #1918 canceled with stages
icon.png

70.5 KB

from glob import glob
import shutil
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
import time
def main(args):
    """Generate a talking-head video from a source image and a driving audio file.

    The run works inside a timestamped directory under ``args.result_dir``.
    The rendered video is finally moved next to that directory as
    ``<save_dir>.mp4``; the working directory is deleted afterwards unless
    ``args.verbose`` is set.

    Args:
        args: Parsed command-line namespace. Reads ``source_image``,
            ``driven_audio``, ``result_dir``, ``pose_style``, ``device``,
            ``batch_size``, ``input_yaw``/``input_pitch``/``input_roll``,
            ``ref_eyeblink``, ``ref_pose``, ``checkpoint_dir``, ``size``,
            ``old_version``, ``preprocess``, ``still``, ``face3dvis``,
            ``expression_scale``, ``enhancer``, ``background_enhancer`` and
            ``verbose``.

    Returns:
        None. Returns early (after printing a message) when no coefficients
        could be extracted from the source image.
    """
    # torch.backends.cudnn.enabled = False
    pic_path = args.source_image
    audio_path = args.driven_audio

    # Every invocation gets its own timestamped working directory.
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)

    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    yaw_angles = args.input_yaw
    pitch_angles = args.input_pitch
    roll_angles = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose

    # Config files are looked up relative to the launched script's location.
    script_root = os.path.split(sys.argv[0])[0]
    config_dir = os.path.join(script_root, 'src/config')
    sadtalker_paths = init_path(args.checkpoint_dir, config_dir, args.size, args.old_version, args.preprocess)

    # init model
    preprocess_model = CropAndExtract(sadtalker_paths, device)
    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)

    def extract_reference_coeffs(video_path, banner):
        # Runs the preprocessor on a reference video inside a sub-directory of
        # save_dir named after the video, and returns the coefficient path.
        video_name = os.path.splitext(os.path.split(video_path)[-1])[0]
        frame_dir = os.path.join(save_dir, video_name)
        os.makedirs(frame_dir, exist_ok=True)
        print(banner)
        coeff_path, _, _ = preprocess_model.generate(
            video_path, frame_dir, args.preprocess, source_image_flag=False)
        return coeff_path

    # crop image and extract 3dmm from image
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
        pic_path, first_frame_dir, args.preprocess,
        source_image_flag=True, pic_size=args.size)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    ref_eyeblink_coeff_path = None
    if ref_eyeblink is not None:
        ref_eyeblink_coeff_path = extract_reference_coeffs(
            ref_eyeblink, '3DMM Extraction for the reference video providing eye blinking')

    if ref_pose is None:
        ref_pose_coeff_path = None
    elif ref_pose == ref_eyeblink:
        # Same video serves both purposes: reuse the coefficients already extracted.
        ref_pose_coeff_path = ref_eyeblink_coeff_path
    else:
        ref_pose_coeff_path = extract_reference_coeffs(
            ref_pose, '3DMM Extraction for the reference video providing pose')

    # audio2coeff
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    # 3dface render
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        face3d_video = os.path.join(save_dir, '3dface.mp4')
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, face3d_video)

    # coeff2video
    data = get_facerender_data(
        coeff_path, crop_pic_path, first_coeff_path, audio_path,
        batch_size, yaw_angles, pitch_angles, roll_angles,
        expression_scale=args.expression_scale, still_mode=args.still,
        preprocess=args.preprocess, size=args.size)
    result = animate_from_coeff.generate(
        data, save_dir, pic_path, crop_info,
        enhancer=args.enhancer, background_enhancer=args.background_enhancer,
        preprocess=args.preprocess, img_size=args.size)

    final_video = save_dir + '.mp4'
    shutil.move(result, final_video)
    print('The generated video is named:', final_video)

    # Intermediate artefacts are only kept when --verbose was requested.
    if not args.verbose:
        shutil.rmtree(save_dir)
if __name__ == '__main__':
    # Command-line entry point: parse options, pick the device, run main()
    # once and print the elapsed time together with its reciprocal.
    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
    # Help text used to read "path to output", copied from --result_dir.
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the model checkpoints")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
    parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
    # Help text used to read "the batch size of facerender", copied from --batch_size.
    parser.add_argument("--expression_scale", type=float, default=1., help="the expression scale of facerender")
    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
    parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="run on the CPU even if CUDA is available")
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
    parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body animation")
    parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images")
    parser.add_argument("--verbose", action="store_true", help="saving the intermediate output or not")
    parser.add_argument("--old_version", action="store_true", help="use the pth other than safetensor version")

    # net structure and parameters
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
    parser.add_argument('--init_path', type=str, default=None, help='Useless')
    # NOTE(review): no type/action is given, so any value passed on the command
    # line arrives as a non-empty (truthy) string. Left as-is to keep the CLI
    # backward-compatible.
    parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    args = parser.parse_args()

    # CUDA is used whenever it is available, unless --cpu forces the CPU.
    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    start = time.time()
    main(args)
    end = time.time()
    # Elapsed seconds, followed by runs-per-second (the reciprocal).
    print(end - start, 1.0 / (end - start))
# this scripts installs necessary requirements and launches main program in webui.py
# borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform
import json
# Interpreter used for every child Python/pip invocation.
python = sys.executable
# External tool settings; both can be overridden through the environment.
git = os.getenv('GIT', "git")
index_url = os.getenv('INDEX_URL', "")
# Filled in lazily by commit_hash().
stored_commit_hash = None
# When True, run_pip() becomes a no-op.
skip_install = False
dir_repos = "repositories"
# Two directory levels above this file's resolved location.
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Only set the analytics flag when the user has not chosen a value already.
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
def check_python_version():
    """Abort with an explanatory error when the interpreter version is unsupported.

    Supported versions are Python 3.10 on Windows and Python 3.7-3.11 on
    every other platform.

    Raises:
        RuntimeError: If the running interpreter is outside the supported
            range. (The previous code raised a bare string, which Python
            rejects with ``TypeError`` and so never showed this message.)
    """
    is_windows = platform.system() == "Windows"
    major, minor, micro = sys.version_info[:3]

    supported_minors = [10] if is_windows else [7, 8, 9, 10, 11]
    if major == 3 and minor in supported_minors:
        return

    # The binary-release hint is only shown to Windows users.
    binary_hint = (
        "Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases"
        if is_windows else ""
    )
    raise RuntimeError(f"""
INCOMPATIBLE PYTHON VERSION
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
{binary_hint}
Use --skip-python-version-check to suppress this warning.
""")
def commit_hash():
    """Return the current git commit hash, computing it at most once.

    The value is cached in the module-level ``stored_commit_hash``. If the
    git command fails for any reason the placeholder ``"<none>"`` is cached
    and returned instead.
    """
    global stored_commit_hash

    if stored_commit_hash is None:
        try:
            stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
        except Exception:
            # Best effort only: a missing git or repository must not stop the launcher.
            stored_commit_hash = "<none>"

    return stored_commit_hash
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    """Run a shell command and return its standard output as text.

    Args:
        command: Command line, executed through the shell.
        desc: Optional message printed before the command runs.
        errdesc: Optional headline for the error raised on failure;
            defaults to ``'Error running command'``.
        custom_env: Environment mapping for the child process; the current
            process environment is used when omitted.
        live: When true the child's output is not captured (it goes to the
            parent's streams) and an empty string is returned on success.

    Returns:
        The decoded stdout of the command, or ``""`` in live mode.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
    """
    if desc is not None:
        print(desc)

    child_env = custom_env if custom_env is not None else os.environ

    if live:
        completed = subprocess.run(command, shell=True, env=child_env)
        if completed.returncode != 0:
            headline = errdesc or 'Error running command'
            raise RuntimeError(
                f"{headline}.\n"
                f"Command: {command}\n"
                f"Error code: {completed.returncode}"
            )
        return ""

    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
        env=child_env,
    )
    if completed.returncode == 0:
        return completed.stdout.decode(encoding="utf8", errors="ignore")

    def _render(stream):
        # Empty streams are reported with an explicit marker.
        if len(stream) > 0:
            return stream.decode(encoding="utf8", errors="ignore")
        return '<empty>'

    headline = errdesc or 'Error running command'
    report = [
        f"{headline}.",
        f"Command: {command}",
        f"Error code: {completed.returncode}",
        f"stdout: {_render(completed.stdout)}",
        f"stderr: {_render(completed.stderr)}",
    ]
    # The message ends with a trailing newline.
    raise RuntimeError("\n".join(report) + "\n")
def check_run(command):
    """Return True when the shell command exits with status 0, else False.

    The command's stdout and stderr are captured and discarded.
    """
    completed = subprocess.run(command, shell=True, capture_output=True)
    return not completed.returncode
def is_installed(package):
    """Report whether ``package`` can be located by the import system.

    Returns False both when no spec is found and when the lookup itself
    raises ``ModuleNotFoundError`` (e.g. a dotted name whose parent is missing).
    """
    try:
        found = importlib.util.find_spec(package) is not None
    except ModuleNotFoundError:
        found = False
    return found
def repo_dir(name):
    """Return the path of repository ``name`` under ``<script_path>/<dir_repos>``."""
    repos_root = os.path.join(script_path, dir_repos)
    return os.path.join(repos_root, name)
def run_python(code, desc=None, errdesc=None):
    """Execute ``code`` with ``<python> -c`` via run() and return its stdout.

    ``code`` is placed between double quotes on the command line as-is.
    """
    command = f'"{python}" -c "{code}"'
    return run(command, desc, errdesc)
def run_pip(args, desc=None):
    """Run ``<python> -m pip <args> --prefer-binary`` through run().

    Does nothing (returns None) when the module-level ``skip_install`` flag
    is set. When ``index_url`` is non-empty it is appended as ``--index-url``.

    Args:
        args: Arguments passed to pip, as a single string.
        desc: Human-readable name of what is being installed; used in the
            progress and error messages.
    """
    if skip_install:
        return

    extra_index = '' if index_url == '' else f' --index-url {index_url}'
    pip_command = f'"{python}" -m pip {args} --prefer-binary{extra_index}'
    return run(pip_command, desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run_python(code):
    """Return True when ``<python> -c "<code>"`` exits with status 0."""
    command = f'"{python}" -c "{code}"'
    return check_run(command)
def git_clone(url, dir, name, commithash=None):
    """Clone ``url`` into ``dir``, or bring an existing checkout to ``commithash``.

    Args:
        url: Repository URL to clone from.
        dir: Target directory. If it already exists no clone is performed.
        name: Human-readable repository name used in progress/error messages.
        commithash: Optional commit to check out. When the directory already
            exists and is at this commit (or no commit is requested) nothing
            is done; otherwise the repository is fetched and the commit is
            checked out.

    Raises:
        RuntimeError: Propagated from run() when a git command fails.
    """
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        # The error description was missing its f-prefix, so the placeholders
        # were printed literally instead of being substituted.
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
def git_pull_recursive(dir):
    """Run ``git pull --autostash`` in every directory under ``dir`` that has a ``.git`` entry.

    The outcome for each repository is printed; a failing pull is reported
    and does not stop the walk.
    """
    for repo_root, _subdirs, _files in os.walk(dir):
        if not os.path.exists(os.path.join(repo_root, '.git')):
            continue

        try:
            pulled = subprocess.check_output([git, '-C', repo_root, 'pull', '--autostash'])
        except subprocess.CalledProcessError as err:
            print(f"Couldn't perform 'git pull' on repository in '{repo_root}':\n{err.output.decode('utf-8').strip()}\n")
        else:
            print(f"Pulled changes for repository in '{repo_root}':\n{pulled.decode('utf-8').strip()}\n")
def run_extension_installer(extension_dir):
    """Run ``install.py`` from ``extension_dir`` if that file exists.

    The installer is executed with the launcher's interpreter and with
    ``PYTHONPATH`` set to the absolute current working directory. Its output
    is printed; any failure is written to stderr instead of being raised.
    """
    installer = os.path.join(extension_dir, "install.py")
    if os.path.isfile(installer):
        try:
            child_env = dict(os.environ)
            child_env['PYTHONPATH'] = os.path.abspath(".")

            output = run(
                f'"{python}" "{installer}"',
                errdesc=f"Error running install.py for extension {extension_dir}",
                custom_env=child_env,
            )
            print(output)
        except Exception as e:
            # A broken extension installer must not abort the launcher.
            print(e, file=sys.stderr)
def prepare_environment():
    """Install the Python dependencies needed before the web UI can start.

    Steps, in order:
      1. Print the interpreter version and current commit hash.
      2. Install torch/torchvision (command overridable via ``TORCH_COMMAND``)
         when either one cannot be found.
      3. Install the requirements file (overridable via ``REQS_FILE``).
      4. On non-Windows platforms, install TTS when it is not yet present.

    Raises:
        RuntimeError: Propagated from run()/run_pip() when an install fails.
    """
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")

    ## check windows
    # Non-Windows platforms default to req.txt, Windows to requirements.txt.
    if sys.platform != 'win32':
        requirements_file = os.environ.get('REQS_FILE', "req.txt")
    else:
        requirements_file = os.environ.get('REQS_FILE', "requirements.txt")

    commit = commit_hash()

    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")

    # The importable package is named "TTS"; module lookup is case-sensitive,
    # so probing for 'tts' never matched and pip was re-run on every launch.
    if sys.platform != 'win32' and not is_installed('TTS'):
        run_pip("install TTS", "install TTS individually in SadTalker, which might not work on windows.")
def start():
    """Build the SadTalker web demo, enable its queue and launch it."""
    print("Launching SadTalker Web UI")

    # Imported inside the function rather than at module import time —
    # presumably so it only happens after prepare_environment() has run.
    from app_sadtalker import sadtalker_demo

    ui = sadtalker_demo()
    ui.queue()
    ui.launch()
# Script entry point: install/verify dependencies first, then launch the web UI.
if __name__ == "__main__":
    prepare_environment()
    start()
\ No newline at end of file
# 模型唯一标识
modelCode = 1095
# 模型名称
modelName=sadtalker_pytorch
# 模型描述
modelDescription=SadTalker是学习3D运动系数(头部姿势、表情)的视频生成模型,使用一张图片和一段语音来生成口型和头、面部视频。
# 应用场景
appScenario=推理,视频生成,家具,电商,医疗,广媒,教育
# 框架类型
frameType=pytorch
\ No newline at end of file
"""run bash scripts/download_models.sh first to prepare the weights file"""
import os
import shutil
from argparse import Namespace
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from cog import BasePredictor, Input, Path
# Checkpoint directory name handed to init_path() in Predictor.setup().
checkpoints = "checkpoints"
class Predictor(BasePredictor):
    """Cog predictor that renders a talking-head video from an image and audio."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"

        sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))

        # init model
        self.preprocess_model = CropAndExtract(sadtalker_paths, device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_paths, device)

        # predict() picks a renderer by preprocess mode. Both entries are
        # currently constructed from identical arguments.
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(sadtalker_paths, device),
            "others": AnimateFromCoeff(sadtalker_paths, device),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body aniamtion when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model

        The ``results`` working directory is wiped and recreated on every
        call. The enhanced video found there is copied to ``/tmp/out.mp4``,
        which is returned.

        Returns:
            Path to the copied output video. NOTE(review): when no
            coefficients can be extracted from the source image the method
            prints a message and returns None despite the ``Path`` annotation.

        Raises:
            FileNotFoundError: If rendering produced no ``*enhanced.mp4`` file.
        """
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = "cuda"
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # crop image and extract 3dmm from image
        # Start from an empty working directory on every prediction.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            print("Can't get the coeffs of the input")
            return

        # The reference videos are handled through their string paths
        # (args.ref_*), matching how the source image and audio are passed.
        # Those string forms were previously computed but never used.
        # NOTE(review): unlike the source-image call, these calls rely on
        # generate()'s defaults for the preprocess mode and source_image_flag —
        # confirm that is intended.
        if args.ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(
                os.path.split(args.ref_eyeblink)[-1]
            )[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                args.ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None

        if args.ref_pose is not None:
            if args.ref_pose == args.ref_eyeblink:
                # Same video for both roles: reuse the extracted coefficients.
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(
                    os.path.split(args.ref_pose)[-1]
                )[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    args.ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None

        # audio2coeff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data, results_dir, args.pic_path, crop_info,
            enhancer=enhancer, background_enhancer=args.background_enhancer,
            preprocess=preprocess)

        output = "/tmp/out.mp4"
        # The enhanced render is located by file name inside the results directory.
        enhanced_videos = [f for f in os.listdir(results_dir) if "enhanced.mp4" in f]
        if not enhanced_videos:
            raise FileNotFoundError(
                f"No '*enhanced.mp4' file was produced in '{results_dir}'"
            )
        mp4_path = os.path.join(results_dir, enhanced_videos[0])
        shutil.copy(mp4_path, output)

        return Path(output)
def load_default():
    """Build a fresh namespace holding the default generation/renderer options.

    A new ``Namespace`` is created on every call, so callers may add or
    overwrite attributes freely.
    """
    defaults = {
        # generation options
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        "face3dvis": False,
        # net structure and parameters
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./src/config/",
        "bfm_model": "BFM_model_front.mat",
        # renderer parameters
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**defaults)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment