Unverified commit e24de2ec authored by gushiqiao, committed by GitHub

update vae path name (#405)

parent d8827789
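This commit renames the `vae_pth`/`tae_pth` config keys and keyword arguments to `vae_path`/`tae_path` throughout. Configs that still carry the old spelling will silently miss the new lookups, so a backward-compatibility shim such as the one below may be useful. This is a minimal sketch assuming the config is a plain dict; `normalize_vae_keys` is a hypothetical helper, not part of this commit:

```python
def normalize_vae_keys(config: dict) -> dict:
    """Hypothetical shim: map the pre-#405 key spellings onto the renamed ones."""
    renames = {"vae_pth": "vae_path", "tae_pth": "tae_path"}
    for old, new in renames.items():
        # Only migrate when the new key is absent, so explicit new-style
        # settings are never overwritten.
        if old in config and new not in config:
            config[new] = config.pop(old)
    return config

cfg = {"use_tae": True, "tae_pth": "/path/to/taew2_1.pth"}
assert normalize_vae_keys(cfg) == {"use_tae": True, "tae_path": "/path/to/taew2_1.pth"}
```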
@@ -492,7 +492,7 @@ def run_inference(
         "vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
         "use_tiling_vae": use_tiling_vae,
         "use_tae": use_tae,
-        "tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
+        "tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
         "lazy_load": lazy_load,
         "do_mm_calib": False,
         "parallel_attn_type": None,
...
@@ -496,7 +496,7 @@ def run_inference(
         "vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
         "use_tiling_vae": use_tiling_vae,
         "use_tae": use_tae,
-        "tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
+        "tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
         "lazy_load": lazy_load,
         "do_mm_calib": False,
         "parallel_attn_type": None,
...
@@ -24,6 +24,6 @@
     "clip_quant_scheme": "fp8",
     "use_tiling_vae": true,
     "use_tae": true,
-    "tae_pth": "/path/to/taew2_1.pth",
+    "tae_path": "/path/to/taew2_1.pth",
     "lazy_load": true
 }
@@ -24,7 +24,7 @@
     "clip_quant_scheme": "fp8",
     "use_tiling_vae": true,
     "use_tae": true,
-    "tae_pth": "/path/to/taew2_1.pth",
+    "tae_path": "/path/to/taew2_1.pth",
     "lazy_load": true,
     "rotary_chunk": true,
     "clean_cuda_cache": true
...
@@ -29,7 +29,7 @@ In some cases, the VAE component can be time-consuming. You can use a lightweigh
 ```python
 {
     "use_tae": true,
-    "tae_pth": "/path to taew2_1.pth"
+    "tae_path": "/path to taew2_1.pth"
 }
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
...
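As a convenience, here is a minimal sketch for fetching those weights with the standard library; it assumes network access and writes `taew2_1.pth` to the current directory:

```python
import urllib.request

# Raw-file URL from the documentation hunk above.
URL = "https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth"
urllib.request.urlretrieve(URL, "taew2_1.pth")
```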
@@ -161,7 +161,7 @@ use_tiling_vae = True # Enable VAE chunked inference
 ```python
 # VAE optimization configuration
 use_tae = True # Use lightweight VAE
-tae_pth = "/path to taew2_1.pth"
+tae_path = "/path to taew2_1.pth"
 ```
 You can download taew2_1.pth [here](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
...
@@ -29,7 +29,7 @@
 ```python
 {
     "use_tae": true,
-    "tae_pth": "/path to taew2_1.pth"
+    "tae_path": "/path to taew2_1.pth"
 }
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
...
@@ -161,7 +161,7 @@ use_tiling_vae = True # Enable VAE chunked inference
 ```python
 # VAE optimization configuration
 use_tae = True
-tae_pth = "/path to taew2_1.pth"
+tae_path = "/path to taew2_1.pth"
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
...
@@ -870,7 +870,7 @@ class Wan22AudioRunner(WanAudioRunner):
         else:
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
@@ -886,7 +886,7 @@ class Wan22AudioRunner(WanAudioRunner):
         else:
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
...
@@ -146,7 +146,7 @@ class WanRunner(DefaultRunner):
         vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
+            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -169,7 +169,7 @@ class WanRunner(DefaultRunner):
         vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
+            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -179,8 +179,8 @@ class WanRunner(DefaultRunner):
             "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.get("use_tae", False):
-            tae_pth = find_torch_model_path(self.config, "tae_pth", self.tiny_vae_name)
-            vae_decoder = self.tiny_vae_cls(vae_pth=tae_pth, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
+            tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
+            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
...
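The last hunk above shows the decoder-selection branch that the renamed keys feed into. Below is a simplified, self-contained sketch of that branch; the classes are stand-ins, not the real `WanVAE`/`WanVAE_tiny` that `WanRunner` wires in:

```python
# Stand-ins for the vae_cls / tiny_vae_cls the runner actually uses.
class FullVAE:
    def __init__(self, vae_path, **kwargs):
        self.vae_path = vae_path

class TinyVAE:
    def __init__(self, vae_path, need_scaled=False):
        self.vae_path = vae_path
        self.need_scaled = need_scaled

def build_vae_decoder(config: dict):
    # Mirrors the use_tae branch above: lightweight decoder when enabled,
    # full decoder built from the vae_config otherwise.
    if config.get("use_tae", False):
        return TinyVAE(vae_path=config["tae_path"], need_scaled=config.get("need_scaled", False))
    return FullVAE(vae_path=config["vae_path"])

dec = build_vae_decoder({"use_tae": True, "tae_path": "taew2_1.pth"})
assert isinstance(dec, TinyVAE) and dec.vae_path == "taew2_1.pth"
```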
@@ -63,7 +63,7 @@ if __name__ == "__main__":
     dtype = dtype_map[args.dtype]
-    model_args = {"vae_pth": args.checkpoint, "dtype": dtype, "device": dev}
+    model_args = {"vae_path": args.checkpoint, "dtype": dtype, "device": dev}
     if args.model_type in "vaew2_1":
         model_args.update({"use_lightvae": args.use_lightvae})
...
@@ -795,7 +795,7 @@ class WanVAE:
     def __init__(
         self,
         z_dim=16,
-        vae_pth="cache/vae_step_411000.pth",
+        vae_path="cache/vae_step_411000.pth",
         dtype=torch.float,
         device="cuda",
         parallel=False,
@@ -895,7 +895,7 @@ class WanVAE:
         # init model
         self.model = (
-            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
+            _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
             .eval()
             .requires_grad_(False)
             .to(device)
...
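Note that `vae_pth` was a keyword parameter of `WanVAE.__init__`, so this rename is API-breaking for any caller that passed it by name. A minimal, self-contained illustration of the failure mode (the function is a stand-in, not the real class):

```python
def init_vae(vae_path="cache/vae_step_411000.pth"):
    """Stand-in for WanVAE.__init__ after the rename."""
    return vae_path

init_vae(vae_path="my.pth")       # fine post-#405
try:
    init_vae(vae_pth="my.pth")    # pre-#405 spelling
except TypeError as e:
    print(e)  # unexpected keyword argument 'vae_pth'
```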
@@ -866,7 +866,7 @@ class Wan2_2_VAE:
         self,
         z_dim=48,
         c_dim=160,
-        vae_pth=None,
+        vae_path=None,
         dim_mult=[1, 2, 4, 4],
         temperal_downsample=[False, True, True],
         dtype=torch.float,
@@ -994,7 +994,7 @@ class Wan2_2_VAE:
         # init model
         self.model = (
             _video_vae(
-                pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0
+                pretrained_path=vae_path, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0
             )
             .eval()
             .requires_grad_(False)
...
@@ -7,7 +7,7 @@ class WanSFVAE:
     def __init__(
         self,
         z_dim=16,
-        vae_pth="cache/vae_step_411000.pth",
+        vae_path="cache/vae_step_411000.pth",
         dtype=torch.float,
         device="cuda",
         parallel=False,
@@ -29,7 +29,7 @@ class WanSFVAE:
         self.std = torch.tensor(std, dtype=torch.float32)
         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
+        self.model = _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
         self.model.clear_cache()

     def to_cpu(self):
...
@@ -11,11 +11,11 @@ class DotDict(dict):

 class WanVAE_tiny(nn.Module):
-    def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
+    def __init__(self, vae_path="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
         super().__init__()
         self.dtype = dtype
         self.device = torch.device("cuda")
-        self.taehv = TAEHV(vae_pth).to(self.dtype)
+        self.taehv = TAEHV(vae_path).to(self.dtype)
         self.temperal_downsample = [True, True, False]
         self.need_scaled = need_scaled
@@ -83,11 +83,11 @@ class WanVAE_tiny(nn.Module):

 class Wan2_2_VAE_tiny(nn.Module):
-    def __init__(self, vae_pth="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
+    def __init__(self, vae_path="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
         super().__init__()
         self.dtype = dtype
         self.device = torch.device("cuda")
-        self.taehv = TAEHV(vae_pth, model_type="wan22").to(self.dtype)
+        self.taehv = TAEHV(vae_path, model_type="wan22").to(self.dtype)
         self.need_scaled = need_scaled
         if self.need_scaled:
             self.latents_mean = [
...
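One way to sanity-check a checkout against this rename is to inspect the constructor signature. The sketch below uses a stub mirroring the post-#405 signature of `WanVAE_tiny` shown above; the real class's module path isn't visible in this diff, so importing it is left to the reader:

```python
import inspect
import torch
import torch.nn as nn

class WanVAE_tiny_stub(nn.Module):
    """Stub mirroring the post-#405 signature of WanVAE_tiny above."""
    def __init__(self, vae_path="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
        super().__init__()
        self.vae_path = vae_path

def has_renamed_kwarg(cls) -> bool:
    # True when __init__ accepts the new vae_path name and not the old vae_pth.
    params = inspect.signature(cls.__init__).parameters
    return "vae_path" in params and "vae_pth" not in params

assert has_renamed_kwarg(WanVAE_tiny_stub)
```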