Unverified Commit e24de2ec authored by gushiqiao, committed by GitHub

update vae path name (#405)

parent d8827789
@@ -492,7 +492,7 @@ def run_inference(
         "vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
         "use_tiling_vae": use_tiling_vae,
         "use_tae": use_tae,
-        "tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
+        "tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
         "lazy_load": lazy_load,
         "do_mm_calib": False,
         "parallel_attn_type": None,
......
@@ -496,7 +496,7 @@ def run_inference(
         "vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
         "use_tiling_vae": use_tiling_vae,
         "use_tae": use_tae,
-        "tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
+        "tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
         "lazy_load": lazy_load,
         "do_mm_calib": False,
         "parallel_attn_type": None,
......
@@ -24,6 +24,6 @@
     "clip_quant_scheme": "fp8",
     "use_tiling_vae": true,
     "use_tae": true,
-    "tae_pth": "/path/to/taew2_1.pth",
+    "tae_path": "/path/to/taew2_1.pth",
     "lazy_load": true
 }
@@ -24,7 +24,7 @@
     "clip_quant_scheme": "fp8",
     "use_tiling_vae": true,
     "use_tae": true,
-    "tae_pth": "/path/to/taew2_1.pth",
+    "tae_path": "/path/to/taew2_1.pth",
     "lazy_load": true,
     "rotary_chunk": true,
     "clean_cuda_cache": true
......
@@ -29,7 +29,7 @@ In some cases, the VAE component can be time-consuming. You can use a lightweigh
 ```python
 {
     "use_tae": true,
-    "tae_pth": "/path to taew2_1.pth"
+    "tae_path": "/path to taew2_1.pth"
 }
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
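Because this commit renames the key from `tae_pth` to `tae_path`, JSON configs written against the old name will no longer be picked up. Below is a minimal migration sketch, assuming a plain dict-style config; the helper name `normalize_vae_keys` is hypothetical and not part of this repository:

```python
def normalize_vae_keys(config: dict) -> dict:
    """Map the legacy *_pth keys to the *_path spelling introduced in this commit."""
    renames = {"tae_pth": "tae_path", "vae_pth": "vae_path"}
    migrated = dict(config)
    for old, new in renames.items():
        if old in migrated and new not in migrated:
            migrated[new] = migrated.pop(old)  # keep the value, drop the legacy key
    return migrated

# A config written before this commit still resolves correctly.
legacy = {"use_tae": True, "tae_pth": "/path/to/taew2_1.pth"}
assert normalize_vae_keys(legacy)["tae_path"] == "/path/to/taew2_1.pth"
```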
......
@@ -161,7 +161,7 @@ use_tiling_vae = True # Enable VAE chunked inference
 ```python
 # VAE optimization configuration
 use_tae = True # Use lightweight VAE
-tae_pth = "/path to taew2_1.pth"
+tae_path = "/path to taew2_1.pth"
 ```
 You can download taew2_1.pth [here](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
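For reference, the runner turns these two settings into a tiny-VAE decoder through the `WanVAE_tiny` constructor visible later in this diff. A hedged sketch of that instantiation; the import path is an assumption, while the keyword names come from this commit:

```python
import torch

# Import path is an assumption; the signature matches WanVAE_tiny in this diff.
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny

decoder = WanVAE_tiny(
    vae_path="/path/to/taew2_1.pth",  # keyword renamed from vae_pth in this commit
    dtype=torch.bfloat16,
    device="cuda",
    need_scaled=False,
).to("cuda")
```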
......
@@ -29,7 +29,7 @@
 ```python
 {
     "use_tae": true,
-    "tae_pth": "/path to taew2_1.pth"
+    "tae_path": "/path to taew2_1.pth"
 }
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
......
@@ -161,7 +161,7 @@ use_tiling_vae = True # Enable VAE chunked inference
 ```python
 # VAE optimization configuration
 use_tae = True
-tae_pth = "/path to taew2_1.pth"
+tae_path = "/path to taew2_1.pth"
 ```
 The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
......
@@ -870,7 +870,7 @@ class Wan22AudioRunner(WanAudioRunner):
         else:
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
@@ -886,7 +886,7 @@ class Wan22AudioRunner(WanAudioRunner):
         else:
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
......
@@ -146,7 +146,7 @@ class WanRunner(DefaultRunner):
         vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
+            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -169,7 +169,7 @@ class WanRunner(DefaultRunner):
         vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
+            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -179,8 +179,8 @@ class WanRunner(DefaultRunner):
             "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.get("use_tae", False):
-            tae_pth = find_torch_model_path(self.config, "tae_pth", self.tiny_vae_name)
-            vae_decoder = self.tiny_vae_cls(vae_pth=tae_pth, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
+            tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
+            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
......
@@ -63,7 +63,7 @@ if __name__ == "__main__":
     dtype = dtype_map[args.dtype]
-    model_args = {"vae_pth": args.checkpoint, "dtype": dtype, "device": dev}
+    model_args = {"vae_path": args.checkpoint, "dtype": dtype, "device": dev}
     if args.model_type in "vaew2_1":
         model_args.update({"use_lightvae": args.use_lightvae})
@@ -795,7 +795,7 @@ class WanVAE:
     def __init__(
         self,
         z_dim=16,
-        vae_pth="cache/vae_step_411000.pth",
+        vae_path="cache/vae_step_411000.pth",
         dtype=torch.float,
         device="cuda",
         parallel=False,
@@ -895,7 +895,7 @@ class WanVAE:
         # init model
         self.model = (
-            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
+            _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
             .eval()
             .requires_grad_(False)
             .to(device)
......
@@ -866,7 +866,7 @@ class Wan2_2_VAE:
         self,
         z_dim=48,
         c_dim=160,
-        vae_pth=None,
+        vae_path=None,
         dim_mult=[1, 2, 4, 4],
         temperal_downsample=[False, True, True],
         dtype=torch.float,
@@ -994,7 +994,7 @@ class Wan2_2_VAE:
         # init model
         self.model = (
             _video_vae(
-                pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0
+                pretrained_path=vae_path, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0
             )
             .eval()
             .requires_grad_(False)
......
@@ -7,7 +7,7 @@ class WanSFVAE:
     def __init__(
         self,
         z_dim=16,
-        vae_pth="cache/vae_step_411000.pth",
+        vae_path="cache/vae_step_411000.pth",
         dtype=torch.float,
         device="cuda",
         parallel=False,
@@ -29,7 +29,7 @@ class WanSFVAE:
         self.std = torch.tensor(std, dtype=torch.float32)
         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
+        self.model = _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
         self.model.clear_cache()

     def to_cpu(self):
......
@@ -11,11 +11,11 @@ class DotDict(dict):

 class WanVAE_tiny(nn.Module):
-    def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
+    def __init__(self, vae_path="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
         super().__init__()
         self.dtype = dtype
         self.device = torch.device("cuda")
-        self.taehv = TAEHV(vae_pth).to(self.dtype)
+        self.taehv = TAEHV(vae_path).to(self.dtype)
         self.temperal_downsample = [True, True, False]
         self.need_scaled = need_scaled
@@ -83,11 +83,11 @@ class WanVAE_tiny(nn.Module):

 class Wan2_2_VAE_tiny(nn.Module):
-    def __init__(self, vae_pth="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
+    def __init__(self, vae_path="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
         super().__init__()
         self.dtype = dtype
         self.device = torch.device("cuda")
-        self.taehv = TAEHV(vae_pth, model_type="wan22").to(self.dtype)
+        self.taehv = TAEHV(vae_path, model_type="wan22").to(self.dtype)
         self.need_scaled = need_scaled
         if self.need_scaled:
             self.latents_mean = [
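A general caveat with keyword renames like `vae_pth` → `vae_path`: call sites that pass the argument positionally keep working, while call sites that use the old keyword fail with a `TypeError`, which is why this commit also updates `self.tiny_vae_cls(vae_path=tae_path, ...)` in `WanRunner`. A self-contained toy illustration (not the repository's class):

```python
class TinyVAEStub:
    """Toy stand-in mirroring the renamed constructor keyword."""

    def __init__(self, vae_path="taew2_1.pth"):
        self.vae_path = vae_path

TinyVAEStub("taew2_1.pth")            # positional call sites are unaffected
TinyVAEStub(vae_path="taew2_1.pth")   # new keyword spelling works
try:
    TinyVAEStub(vae_pth="taew2_1.pth")  # legacy keyword now raises
except TypeError as err:
    print(f"legacy keyword rejected: {err}")
```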
......