{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stage-I ckpt path \n", "stage_1_path = \"./checkpoints/stage1.pt\"\n", "stage_2_path = \"./checkpoints/stage2.pt\"\n", "save_dir=\"vis_270p_1080p\"\n", "# 2 ~ 3\n", "shift_t = 2.5\n", "# 4 ~ 6\n", "sample_step = 5\n", "# 10 ~ 13\n", "cfg_second = 13\n", " # 650 ~ 750\n", "deg_latent_strength=675\n", "# stage_1_hw \n", "\n", "#TODO Stage I CFG here\n", "cfg_first = 8\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "current_directory = os.getcwd()\n", "os.chdir(os.path.dirname(current_directory))\n", "new_directory = os.getcwd()\n", "print(f\"working directory: {new_directory}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "import os\n", "import argparse\n", "import torch\n", "import numpy as np\n", "import copy\n", "\n", "from sat.model.base_model import get_model\n", "from arguments import get_args\n", "from torchvision.io.video import write_video\n", "\n", "from flow_video import FlowEngine\n", "from diffusion_video import SATVideoDiffusionEngine\n", "\n", "import os\n", "from utils import disable_all_init, decode, prepare_input, save_memory_encode_first_stage, save_mem_decode, seed_everything\n", "disable_all_init()\n", "\n", "\n", "def init_model(model, second_model, args, second_args):\n", " share_cache = dict()\n", " second_share_cache = dict()\n", " if hasattr(args, 'share_cache_config'):\n", " for k, v in args.share_cache_config.items():\n", " share_cache[k] = v\n", " if hasattr(second_args, 'share_cache_config'):\n", " for k, v in second_args.share_cache_config.items():\n", " second_share_cache[k] = v\n", "\n", " for n, m in model.named_modules():\n", " m.share_cache = share_cache \n", " if hasattr(m, \"register_new_modules\"):\n", " m.register_new_modules()\n", " for n, m in second_model.named_modules():\n", " m.share_cache = second_share_cache \n", " if hasattr(m, \"register_new_modules\"):\n", " m.register_new_modules() \n", "\n", " weight_path = args.inf_ckpt\n", " weight = torch.load(weight_path, map_location=\"cpu\")\n", " if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in weight[\"module\"]:\n", " del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n", " del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n", " msg = model.load_state_dict(weight[\"module\"], strict=False)\n", " print(msg)\n", " second_weight_path = args.inf_ckpt2\n", " second_weight = torch.load(second_weight_path, map_location=\"cpu\")\n", " if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in second_weight[\"module\"]:\n", " del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n", " del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n", " second_msg = second_model.load_state_dict(second_weight[\"module\"], strict=False)\n", " print(second_msg)\n", "\n", "def get_first_results(model, text, num_frames, H, W, neg_prompt=None):\n", " \"\"\"Get first Stage results.\n", "\n", " Args:\n", " model (nn.Module): first stage model.\n", " text (str): text prompt\n", " num_frames (int): number of frames\n", " H (int): height of the first stage results\n", " W (int): width of the first stage results\n", " neg_prompt (str): negative prompt\n", "\n", " Returns:\n", " Tensor: first stage video.\n", " \"\"\"\n", " device = 'cuda'\n", " T = 1 + (num_frames - 1) // 4\n", " 
F = 8\n", " motion_text_prefix = [\n", " 'very low motion,',\n", " 'low motion,',\n", " 'medium motion,',\n", " 'high motion,',\n", " 'very high motion,',\n", " ]\n", " pos_prompt = \"\"\n", " if neg_prompt is None:\n", " neg_prompt = \"\"\n", " with torch.no_grad():\n", " model.to('cuda')\n", " input_negative_prompt = motion_text_prefix[\n", " 0] + ', ' + motion_text_prefix[1] + neg_prompt\n", " c, uc = prepare_input(text,\n", " model,\n", " T,\n", " negative_prompt=input_negative_prompt,\n", " pos_prompt=pos_prompt)\n", " with torch.no_grad(), torch.amp.autocast(enabled=True,\n", " device_type='cuda',\n", " dtype=torch.bfloat16):\n", " samples_z = model.sample(\n", " c,\n", " uc=uc,\n", " batch_size=1,\n", " shape=(T, 16, H // F, W // F),\n", " num_steps=model.share_cache.get('first_sample_step', None),\n", " )\n", " samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n", "\n", " model.to('cpu')\n", " torch.cuda.empty_cache()\n", " first_stage_model = model.first_stage_model\n", " first_stage_model = first_stage_model.to(device)\n", "\n", " latent = 1.0 / model.scale_factor * samples_z\n", "\n", " samples = decode(first_stage_model, latent)\n", " model.to('cpu')\n", " return samples\n", "def get_second_results(model, text, first_stage_samples, num_frames):\n", " \"\"\"Get second Stage results.\n", "\n", " Args:\n", " model (nn.Module): second stage model.\n", " text (str): text prompt\n", " first_stage_samples (Tensor): first stage results\n", " num_frames (int): number of frames\n", " Returns:\n", " Tensor: second stage results.\n", " \"\"\"\n", "\n", " t, h, w, c = first_stage_samples.shape\n", " first_stage_samples = first_stage_samples[:num_frames]\n", " first_stage_samples = (first_stage_samples / 255.)\n", " first_stage_samples = (first_stage_samples - 0.5) / 0.5\n", "\n", " target_size = model.share_cache.get('target_size', None)\n", " if target_size is None:\n", " upscale_factor = model.share_cache.get('upscale_factor', 8)\n", " H = int(h * upscale_factor) // 16 * 16\n", " W = int(w * upscale_factor) // 16 * 16\n", " else:\n", " H, W = target_size\n", " H = H // 16 * 16\n", " W = W // 16 * 16\n", "\n", " first_stage_samples = first_stage_samples.permute(0, 3, 1, 2).to('cuda')\n", "\n", " ref_x = torch.nn.functional.interpolate(first_stage_samples,\n", " size=(H, W),\n", " mode='bilinear',\n", " align_corners=False,\n", " antialias=True)\n", " ref_x = ref_x[:num_frames][None]\n", "\n", " ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n", "\n", " first_stage_model = model.first_stage_model\n", " print(f'start encoding first stage results to high resolution')\n", " with torch.no_grad():\n", " first_stage_dtype = next(model.first_stage_model.parameters()).dtype\n", " model.first_stage_model.cuda()\n", " ref_x = save_memory_encode_first_stage(\n", " ref_x.contiguous().to(first_stage_dtype).cuda(), model)\n", "\n", " ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n", " ref_x = ref_x.to(model.dtype)\n", " print(f'finish encoding first stage results, and starting stage II')\n", "\n", " device = 'cuda'\n", "\n", " model.to(device)\n", "\n", " pos_prompt = ''\n", " input_negative_prompt = \"\"\n", "\n", " c, uc = prepare_input(text,\n", " model,\n", " num_frames,\n", " negative_prompt=input_negative_prompt,\n", " pos_prompt=pos_prompt)\n", "\n", " T = 1 + (num_frames - 1) // 4\n", " F = 8\n", " with torch.no_grad(), torch.amp.autocast(enabled=True,\n", " device_type='cuda',\n", " dtype=torch.bfloat16):\n", " samples_z = model.sample(\n", " ref_x,\n", " c,\n", " uc=uc,\n", " 
    "        samples_z = model.sample(\n",
    "            ref_x,\n",
    "            c,\n",
    "            uc=uc,\n",
    "            batch_size=1,\n",
    "            shape=(T, 16, H // F, W // F),\n",
    "            num_steps=model.share_cache.get('sample_step', 5),\n",
    "            method='euler',\n",
    "            cfg=model.share_cache.get('cfg', 7.5),\n",
    "        )\n",
    "        samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
    "\n",
    "    model.to('cpu')\n",
    "    torch.cuda.empty_cache()\n",
    "    first_stage_model = model.first_stage_model\n",
    "    first_stage_model = first_stage_model.to(device)\n",
    "\n",
    "    latent = 1.0 / model.scale_factor * samples_z\n",
    "    print('start spatiotemporal slice decoding')\n",
    "    samples = save_mem_decode(first_stage_model, latent)\n",
    "    print('finished spatiotemporal slice decoding')\n",
    "    model.to('cpu')\n",
    "    return samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# single-process distributed environment for local inference\n",
    "os.environ[\"LOCAL_RANK\"] = \"0\"\n",
    "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
    "os.environ[\"RANK\"] = \"0\"\n",
    "os.environ[\"MASTER_ADDR\"] = \"0.0.0.0\"\n",
    "os.environ[\"MASTER_PORT\"] = \"12345\"\n",
    "\n",
    "py_parser = argparse.ArgumentParser(add_help=False)\n",
    "args_list = [\n",
    "    \"--base\", \"flashvideo/configs/stage1.yaml\",\n",
    "    \"--second\", \"flashvideo/configs/stage2.yaml\",\n",
    "    \"--inf-ckpt\", stage_1_path,\n",
    "    \"--inf-ckpt2\", stage_2_path,\n",
    "]\n",
    "known, args_list = py_parser.parse_known_args(args=args_list)\n",
    "second_args_list = copy.deepcopy(args_list)\n",
    "\n",
    "args = get_args(args_list)\n",
    "args = argparse.Namespace(**vars(args), **vars(known))\n",
    "del args.deepspeed_config\n",
    "# single-GPU inference: no context/model parallelism, no activation checkpointing\n",
    "args.model_config.first_stage_config.params.cp_size = 1\n",
    "args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
    "args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
    "args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n",
    "\n",
    "# point the second parser at the Stage-II config\n",
    "second_args_list[1] = args.second[0]\n",
    "second_args = get_args(second_args_list)\n",
    "second_args = argparse.Namespace(**vars(second_args), **vars(known))\n",
    "del second_args.deepspeed_config\n",
    "second_args.model_config.first_stage_config.params.cp_size = 1\n",
    "second_args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
    "second_args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
    "second_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cls = SATVideoDiffusionEngine\n",
    "second_model_cls = FlowEngine\n",
    "local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n",
    "torch.cuda.set_device(local_rank)\n",
    "\n",
    "second_model = get_model(second_args, second_model_cls)\n",
    "model = get_model(args, model_cls)\n",
    "\n",
    "init_model(model, second_model, args, second_args)\n",
    "\n",
    "model.eval()\n",
    "second_model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fold LoRA weights into the base weights wherever a module provides merge_lora\n",
    "for n, m in model.named_modules():\n",
    "    if hasattr(m, \"merge_lora\"):\n",
    "        m.merge_lora()\n",
    "        print(f\"merge lora of {n}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_frames = 49\n",
    "second_num_frames = 49\n",
    "\n",
    "stage_1_hw = (270, 480)\n",
    "stage_2_hw = (1080, 1920)\n",
    "\n",
    "# make sure both resolutions are divisible by 16\n",
    "stage_1_hw = (stage_1_hw[0] // 16 * 16, stage_1_hw[1] // 16 * 16)\n",
    "stage_2_hw = (stage_2_hw[0] // 16 * 16, stage_2_hw[1] // 16 * 16)\n",
    "\n",
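    "# per-stage shorthand: frame count, frame height/width, latent channels,\n",
    "# and the 8x spatial downsampling factor of the VAE\n",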
    "sample_func = model.sample\n",
    "T, H, W, C, F = num_frames, stage_1_hw[0], stage_1_hw[1], args.latent_channels, 8\n",
    "S_T, S_H, S_W, S_C, S_F = second_num_frames, stage_2_hw[0], stage_2_hw[1], args.latent_channels, 8\n",
    "\n",
    "seed_everything(0)\n",
    "\n",
    "text = (\"Sunny day. The camera smoothly pushes in through an ornate garden archway, \"\n",
    "        \"delicately adorned with climbing ivy. Beyond the archway, a secret, tranquil \"\n",
    "        \"garden is revealed, brimming with a vibrant array of blooming flowers in a \"\n",
    "        \"myriad of colors. A beautiful young woman with long wavy brown hair smiles at \"\n",
    "        \"the camera. She wears a red hat with a rich fabric texture, a yellow sweater, \"\n",
    "        \"and a black pleated skirt, and sits holding a dog.\")\n",
    "\n",
    "neg_text = \"\"\n",
    "\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "enu_index = \"1\"\n",
    "model.share_cache[\"cfg\"] = cfg_first\n",
    "\n",
    "first_stage_samples = get_first_results(model, text, num_frames, H, W, neg_text)\n",
    "\n",
    "print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}.mp4\")\n",
    "write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}.mp4',\n",
    "            fps=8,\n",
    "            video_array=first_stage_samples,\n",
    "            options={'crf': '14'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "second_num_frames = 49\n",
    "# configure Stage II through its share_cache\n",
    "second_model.share_cache[\"ref_noise_step\"] = deg_latent_strength\n",
    "second_model.share_cache[\"sample_ref_noise_step\"] = deg_latent_strength\n",
    "second_model.share_cache.pop(\"ref_noise_step_range\", None)\n",
    "second_model.share_cache[\"target_size\"] = stage_2_hw\n",
    "second_model.share_cache[\"shift_t\"] = shift_t\n",
    "second_model.share_cache[\"sample_step\"] = sample_step\n",
    "second_model.share_cache[\"cfg\"] = cfg_second\n",
    "post_fix = f'''noise_{second_model.share_cache[\"ref_noise_step\"]}_step_{second_model.share_cache[\"sample_step\"]}_cfg_{second_model.share_cache[\"cfg\"]}_shift_{second_model.share_cache[\"shift_t\"]}_size_{stage_2_hw[0]}x{stage_2_hw[1]}'''\n",
    "second_model.share_cache[\"time_size_embedding\"] = True\n",
    "second_stage_samples = get_second_results(second_model,\n",
    "                                          text,\n",
    "                                          first_stage_samples,\n",
    "                                          second_num_frames)\n",
    "\n",
    "print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_second.mp4\")\n",
    "write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_second.mp4',\n",
    "            fps=8,\n",
    "            video_array=second_stage_samples.cpu(),\n",
    "            options={'crf': '14'})\n",
    "\n",
    "# save a joint video: upsample the Stage-I frames to the Stage-II size\n",
    "part_first_stage = first_stage_samples[:second_num_frames]\n",
    "\n",
    "target_h, target_w = second_stage_samples.shape[1], second_stage_samples.shape[2]\n",
    "part_first_stage = torch.nn.functional.interpolate(part_first_stage.permute(0, 3, 1, 2).contiguous(),\n",
    "                                                   size=(target_h, target_w),\n",
    "                                                   mode=\"bilinear\",\n",
    "                                                   align_corners=False,\n",
    "                                                   antialias=True)\n",
    "part_first_stage = part_first_stage.permute(0, 2, 3, 1).contiguous()\n",
    "\n",
    "# stack both stages along the width for a side-by-side comparison\n",
    "joint_video = torch.cat([part_first_stage.cpu(), second_stage_samples.cpu()], dim=-2)\n",
    "print(f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4')\n",
    "write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4',\n",
    "            fps=8,\n",
    "            video_array=joint_video.cpu(),\n",
    "            options={'crf': '15'})\n"
   ]
  },
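  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional preview: embed the joint comparison video in the notebook.\n",
    "# A minimal sketch assuming IPython's display machinery is available and the\n",
    "# joint video above was written successfully; note that embedding a 1080p clip\n",
    "# inflates the notebook file size. Adjust the display width to taste.\n",
    "from IPython.display import Video\n",
    "\n",
    "Video(f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4',\n",
    "      embed=True, width=960)\n"
   ]
  }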
"/mnt/bn/foundation-ads/shilong/conda/code/cogvideo-5b/sat/demo_ab.ipynb", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }