Commit b93699b0 authored by TorynCurtis, committed by Yang Yong (雍洋)

wan model cpu_offload (#3)

* Modified main.py, the T5 model, the Wan model, three weights files, and three infer files; registered a new operator in common's conv3d ops

* Modified the Conv3dWeightForceBF16 operator and updated its use in Wan's pre_weights

* Fixed a bug in the imports

* Fixed a bug where WanPreWeights and WanTransformerWeights were missing self.config

* Fixed a config bug; with cpu_offload enabled, the VAE stage still had a tensors-on-different-devices bug

* Fixed the VAE-stage device-transfer bug

* Fixed a bug where scale still had to be reassigned after mean and inv_std were moved
parent 73a30e28
@@ -56,3 +56,13 @@ class Conv3dWeight(Conv3dWeightTemplate):
         self.weight = self.weight.cuda()
         if self.bias is not None:
             self.bias = self.bias.cuda()
+
+
+@CONV3D_WEIGHT_REGISTER('Defaultt-Force-BF16')
+class Conv3dWeightForceBF16(Conv3dWeight):
+    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
+        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
+
+    def load(self, weight_dict):
+        self.weight = weight_dict[self.weight_name].to(torch.bfloat16).cuda()
+        self.bias = weight_dict[self.bias_name].to(torch.bfloat16).cuda() if self.bias_name is not None else None
\ No newline at end of file
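Note on the new operator: assuming CONV3D_WEIGHT_REGISTER works like the other registries in this repo (a string-keyed map from name to class), the force-BF16 variant would be constructed and loaded roughly as below. The toy weight_dict and shapes are illustrative, and a CUDA device is required since load() moves tensors to the GPU.

import torch
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER

# Toy state dict standing in for the real checkpoint (shapes are made up).
weight_dict = {
    "patch_embedding.weight": torch.randn(16, 3, 1, 2, 2),
    "patch_embedding.bias": torch.randn(16),
}

conv = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"](
    "patch_embedding.weight", "patch_embedding.bias", stride=(1, 2, 2)
)
conv.load(weight_dict)  # casts both tensors to bfloat16 and moves them to the GPU
assert conv.weight.dtype == torch.bfloat16 and conv.weight.is_cuda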
@@ -14,7 +14,7 @@ class WanPostInfer:
             x, (x.shape[1],), None, None, 1e-6
         ).type_as(x)
         out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
-        x = torch.addmm(weights.head_bias, out, weights.head_weight.t())
+        x = weights.head.apply(out)
         x = self.unpatchify(x, grid_sizes)
         return [u.float() for u in x]
...
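The refactor above swaps an explicit torch.addmm for weights.head.apply(out). Assuming MMWeight.apply computes out @ weight.t() + bias, which is what the addmm it replaces did, the two forms are numerically identical; a torch-only check of that assumption:

import torch

out = torch.randn(4, 1536)           # illustrative shapes
head_weight = torch.randn(64, 1536)
head_bias = torch.randn(64)

old = torch.addmm(head_bias, out, head_weight.t())  # the removed form
new = out @ head_weight.t() + head_bias             # what .apply() presumably computes
assert torch.allclose(old, new, atol=1e-5)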
@@ -30,7 +30,7 @@ class WanPreInfer:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

         # embeddings
-        x = [weights.patch_embedding(u.unsqueeze(0)) for u in x]
+        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
         grid_sizes = torch.stack(
             [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
         )
@@ -45,24 +45,12 @@ class WanPreInfer:
         )
         embed = sinusoidal_embedding_1d(self.freq_dim, t)
-        embed = torch.addmm(
-            weights.time_embedding_0_bias,
-            embed,
-            weights.time_embedding_0_weight.t(),
-        )
+        embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
-        embed = torch.addmm(
-            weights.time_embedding_2_bias,
-            embed,
-            weights.time_embedding_2_weight.t(),
-        )
+        embed = weights.time_embedding_2.apply(embed)
         embed0 = torch.nn.functional.silu(embed)
-        embed0 = torch.addmm(
-            weights.time_projection_1_bias,
-            embed0,
-            weights.time_projection_1_weight.t(),
-        ).unflatten(1, (6, self.dim))
+        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

         # text embeddings
         stacked = torch.stack(
@@ -71,45 +59,17 @@ class WanPreInfer:
                 for u in context
             ]
         )
-        out = torch.addmm(
-            weights.text_embedding_0_bias,
-            stacked.squeeze(0),
-            weights.text_embedding_0_weight.t(),
-        )
+        out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
-        context = torch.addmm(
-            weights.text_embedding_2_bias,
-            out,
-            weights.text_embedding_2_weight.t(),
-        )
+        context = weights.text_embedding_2.apply(out)

         if self.task == 'i2v':
-            context_clip = torch.nn.functional.layer_norm(
-                clip_fea,
-                normalized_shape=(clip_fea.shape[1],),
-                weight=weights.proj_0_weight,
-                bias=weights.proj_0_bias,
-                eps=1e-5,
-            )
-            context_clip = torch.addmm(
-                weights.proj_1_bias,
-                context_clip,
-                weights.proj_1_weight.t(),
-            )
+            context_clip = weights.proj_0.apply(clip_fea)
+            context_clip = weights.proj_1.apply(context_clip)
             context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
-            context_clip = torch.addmm(
-                weights.proj_3_bias,
-                context_clip,
-                weights.proj_3_weight.t(),
-            )
-            context_clip = torch.nn.functional.layer_norm(
-                context_clip,
-                normalized_shape=(context_clip.shape[1],),
-                weight=weights.proj_4_weight,
-                bias=weights.proj_4_bias,
-                eps=1e-5,
-            )
+            context_clip = weights.proj_3.apply(context_clip)
+            context_clip = weights.proj_4.apply(context_clip)
             context = torch.concat([context_clip, context], dim=0)
         return (
...
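For reference, the i2v branch above now expresses the image-embedding MLP as a chain of weight wrappers. Written out with plain torch ops (and made-up dimensions), the chain presumably computes LayerNorm, Linear, GELU, Linear, LayerNorm, exactly as the removed code did:

import torch

clip_fea = torch.randn(257, 1280)                    # illustrative shapes
ln0_w, ln0_b = torch.ones(1280), torch.zeros(1280)
lin1_w, lin1_b = torch.randn(5120, 1280), torch.randn(5120)
lin3_w, lin3_b = torch.randn(5120, 5120), torch.randn(5120)
ln4_w, ln4_b = torch.ones(5120), torch.zeros(5120)

h = torch.nn.functional.layer_norm(clip_fea, (1280,), ln0_w, ln0_b, eps=1e-5)  # proj_0
h = h @ lin1_w.t() + lin1_b                                                    # proj_1
h = torch.nn.functional.gelu(h, approximate="none")
h = h @ lin3_w.t() + lin3_b                                                    # proj_3
h = torch.nn.functional.layer_norm(h, (5120,), ln4_w, ln4_b, eps=1e-5)         # proj_4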
@@ -107,13 +107,7 @@ class WanTransformerInfer:
         x = x + y * embed0[2].squeeze(0)
-        norm3_out = torch.nn.functional.layer_norm(
-            x,
-            normalized_shape=(x.shape[1],),
-            weight=weights.norm3_weight,
-            bias=weights.norm3_bias,
-            eps=1e-6,
-        )
+        norm3_out = weights.norm3.apply(x)
         if self.task == 'i2v':
             context_img = context[:257]
...
@@ -32,6 +32,9 @@ class WanModel:
         if config['parallel_attn']:
             parallelize_wan(self)

+        if self.config['cpu_offload']:
+            self.to_cpu()
+
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
@@ -69,7 +72,7 @@ class WanModel:
         weight_dict = self._load_ckpt()
         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
-        self.post_weight = self.post_weight_class()
+        self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # load weights
         self.pre_weight.load_weights(weight_dict)
@@ -85,6 +88,16 @@ class WanModel:
         self.scheduler = scheduler
         self.transformer_infer.set_scheduler(scheduler)

+    def to_cpu(self):
+        self.pre_weight.to_cpu()
+        self.post_weight.to_cpu()
+        self.transformer_weights.to_cpu()
+
+    def to_cuda(self):
+        self.pre_weight.to_cuda()
+        self.post_weight.to_cuda()
+        self.transformer_weights.to_cuda()
+
     @torch.no_grad()
     def infer(self, text_encoders_output, image_encoder_output, args):
...
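The new to_cpu()/to_cuda() pair simply fans out to the three weight holders. A self-contained mock of that delegation pattern (class names here are illustrative, not the repo's; requires a CUDA device):

import torch

class Part:                                   # stands in for pre/post/transformer weights
    def __init__(self, n):
        self.w = torch.randn(n, n)
    def to_cpu(self):
        self.w = self.w.cpu()
    def to_cuda(self):
        self.w = self.w.cuda()

class Model:                                  # stands in for WanModel
    def __init__(self):
        self.parts = [Part(64), Part(64), Part(64)]
    def to_cpu(self):
        for p in self.parts:
            p.to_cpu()
    def to_cuda(self):
        for p in self.parts:
            p.to_cuda()

m = Model()
m.to_cuda()
assert all(p.w.is_cuda for p in m.parts)
m.to_cpu()
assert not any(p.w.is_cuda for p in m.parts)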
-class WanPostWeights:
-    def __init__(self):
-        pass
-
-    def load_weights(self, weight_dict):
-        head_layers = {"head": ["head.weight", "head.bias", "modulation"]}
-        for param_name, param_keys in head_layers.items():
-            for key in param_keys:
-                weight_path = f"{param_name}.{key}"
-                key = key.split('.')
-                setattr(self, f"{param_name}_{key[-1]}", weight_dict[weight_path])
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
+from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+
+
+class WanPostWeights:
+    def __init__(self, config):
+        self.config = config
+
+    def load_weights(self, weight_dict):
+        self.head = MM_WEIGHT_REGISTER["Default"]('head.head.weight', 'head.head.bias')
+        self.head_modulation = weight_dict['head.modulation']
+
+        self.weight_list = [
+            self.head,
+            self.head_modulation,
+        ]
+
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, MMWeightTemplate):
+                mm_weight.set_config(self.config['mm_config'])
+                mm_weight.load(weight_dict)
+
+    def to_cpu(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, MMWeightTemplate):
+                mm_weight.to_cpu()
+            else:
+                mm_weight.cpu()
+
+    def to_cuda(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, MMWeightTemplate):
+                mm_weight.to_cuda()
+            else:
+                mm_weight.cuda()
\ No newline at end of file
 import torch
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
+from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
+from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate


 class WanPreWeights:
     def __init__(self, config):
         self.in_dim = config["in_dim"]
         self.dim = config["dim"]
         self.patch_size = (1, 2, 2)
+        self.config = config

     def load_weights(self, weight_dict):
-        layers = {
-            "text_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
-            "time_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
-            "time_projection": {"1": ["weight", "bias"]},
-        }
-
-        self.patch_embedding = (
-            torch.nn.Conv3d(
-                self.in_dim,
-                self.dim,
-                kernel_size=self.patch_size,
-                stride=self.patch_size,
-            )
-            .to(torch.bfloat16)
-            .cuda()
-        )
-        self.patch_embedding.weight.data.copy_(weight_dict["patch_embedding.weight"])
-        self.patch_embedding.bias.data.copy_(weight_dict["patch_embedding.bias"])
-
-        for module_name, sub_layers in layers.items():
-            for param_name, param_keys in sub_layers.items():
-                for key in param_keys:
-                    weight_path = f"{module_name}.{param_name}.{key}"
-                    setattr(
-                        self,
-                        f"{module_name}_{param_name}_{key}",
-                        weight_dict[weight_path],
-                    )
+        self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]('patch_embedding.weight', 'patch_embedding.bias', stride=self.patch_size)
+
+        self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]('text_embedding.0.weight', 'text_embedding.0.bias')
+        self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]('text_embedding.2.weight', 'text_embedding.2.bias')
+        self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]('time_embedding.0.weight', 'time_embedding.0.bias')
+        self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]('time_embedding.2.weight', 'time_embedding.2.bias')
+        self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]('time_projection.1.weight', 'time_projection.1.bias')
+
+        self.weight_list = [
+            self.patch_embedding,
+            self.text_embedding_0,
+            self.text_embedding_2,
+            self.time_embedding_0,
+            self.time_embedding_2,
+            self.time_projection_1,
+        ]

         if 'img_emb.proj.0.weight' in weight_dict.keys():
-            MLP_layers = {
-                "proj_0_weight": "proj.0.weight",
-                "proj_0_bias": "proj.0.bias",
-                "proj_1_weight": "proj.1.weight",
-                "proj_1_bias": "proj.1.bias",
-                "proj_3_weight": "proj.3.weight",
-                "proj_3_bias": "proj.3.bias",
-                "proj_4_weight": "proj.4.weight",
-                "proj_4_bias": "proj.4.bias",
-            }
-
-            for layer_name, weight_keys in MLP_layers.items():
-                weight_path = f"img_emb.{weight_keys}"
-                setattr(self, layer_name, weight_dict[weight_path])
+            self.proj_0 = LN_WEIGHT_REGISTER["Default"]('img_emb.proj.0.weight', 'img_emb.proj.0.bias', eps=1e-5)
+            self.proj_1 = MM_WEIGHT_REGISTER["Default"]('img_emb.proj.1.weight', 'img_emb.proj.1.bias')
+            self.proj_3 = MM_WEIGHT_REGISTER["Default"]('img_emb.proj.3.weight', 'img_emb.proj.3.bias')
+            self.proj_4 = LN_WEIGHT_REGISTER["Default"]('img_emb.proj.4.weight', 'img_emb.proj.4.bias', eps=1e-5)
+            self.weight_list.append(self.proj_0)
+            self.weight_list.append(self.proj_1)
+            self.weight_list.append(self.proj_3)
+            self.weight_list.append(self.proj_4)
+
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
+                mm_weight.set_config(self.config['mm_config'])
+                mm_weight.load(weight_dict)
+
+    def to_cpu(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
+                mm_weight.to_cpu()
+
+    def to_cuda(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
+                mm_weight.to_cuda()
\ No newline at end of file
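WanPreWeights now follows the same three-step lifecycle for every wrapper: construct with checkpoint key names, set_config, then load from the state dict. A torch-only mock of that contract; the real templates also carry quantization and kernel options via mm_config, and this sketch only assumes the method names visible above:

import torch

class MockMMWeight:
    """Minimal stand-in for MMWeightTemplate's visible interface."""
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name
    def set_config(self, config):
        self.config = config
    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name]
        self.bias = weight_dict[self.bias_name]
    def apply(self, x):
        return x @ self.weight.t() + self.bias   # assumed semantics of .apply()

wd = {"text_embedding.0.weight": torch.randn(8, 4), "text_embedding.0.bias": torch.randn(8)}
mm = MockMMWeight("text_embedding.0.weight", "text_embedding.0.bias")
mm.set_config({})                 # mm_config would normally come from self.config['mm_config']
mm.load(wd)
y = mm.apply(torch.randn(2, 4))   # -> shape (2, 8)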
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER
 from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate


 class WanTransformerWeights:
     def __init__(self, config):
         self.blocks_num = config["num_layers"]
         self.task = config['task']
+        self.config = config
         if config['do_mm_calib']:
             self.mm_type = 'Calib'
         else:
@@ -13,69 +15,88 @@ class WanTransformerWeights:
     def load_weights(self, weight_dict):
         self.blocks_weights = [
-            WanTransformerAttentionBlock(i, self.task, self.mm_type) for i in range(self.blocks_num)
+            WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)
         ]
         for block in self.blocks_weights:
             block.load_weights(weight_dict)

+    def to_cpu(self):
+        for block in self.blocks_weights:
+            block.to_cpu()
+
+    def to_cuda(self):
+        for block in self.blocks_weights:
+            block.to_cuda()
+

 class WanTransformerAttentionBlock:
-    def __init__(self, block_index, task, mm_type):
+    def __init__(self, block_index, task, mm_type, config):
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
+        self.config = config

     def load_weights(self, weight_dict):
-        if self.task == 't2v':
-            layers = {
-                "self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
-                "self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
-                "self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
-                "self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
-                "self_attn_norm_q_weight": "self_attn.norm_q.weight",
-                "self_attn_norm_k_weight": "self_attn.norm_k.weight",
-                "norm3_weight": "norm3.weight",
-                "norm3_bias": "norm3.bias",
-                "cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
-                "cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
-                "cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
-                "cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
-                "cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
-                "cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
-                "ffn_0": ["ffn.0.weight", "ffn.0.bias"],
-                "ffn_2": ["ffn.2.weight", "ffn.2.bias"],
-                "modulation": "modulation",
-            }
-        elif self.task == 'i2v':
-            layers = {
-                "self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
-                "self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
-                "self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
-                "self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
-                "self_attn_norm_q_weight": "self_attn.norm_q.weight",
-                "self_attn_norm_k_weight": "self_attn.norm_k.weight",
-                "norm3_weight": "norm3.weight",
-                "norm3_bias": "norm3.bias",
-                "cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
-                "cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
-                "cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
-                "cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
-                "cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
-                "cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
-                "cross_attn_k_img": ["cross_attn.k_img.weight", "cross_attn.k_img.bias"],
-                "cross_attn_v_img": ["cross_attn.v_img.weight", "cross_attn.v_img.bias"],
-                "cross_attn_norm_k_img_weight": "cross_attn.norm_k_img.weight",
-                "ffn_0": ["ffn.0.weight", "ffn.0.bias"],
-                "ffn_2": ["ffn.2.weight", "ffn.2.bias"],
-                "modulation": "modulation",
-            }
-
-        for layer_name, weight_keys in layers.items():
-            if isinstance(weight_keys, list):
-                weight_key, bias_key = weight_keys
-                weight_path = f"blocks.{self.block_index}.{weight_key}"
-                bias_path = f"blocks.{self.block_index}.{bias_key}"
-                setattr(self, layer_name, MM_WEIGHT_REGISTER[self.mm_type](weight_path, bias_path))
-                getattr(self, layer_name).load(weight_dict)
-            else:
-                weight_path = f"blocks.{self.block_index}.{weight_keys}"
-                setattr(self, layer_name, weight_dict[weight_path])
+        self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.q.weight', f'blocks.{self.block_index}.self_attn.q.bias')
+        self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.k.weight', f'blocks.{self.block_index}.self_attn.k.bias')
+        self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.v.weight', f'blocks.{self.block_index}.self_attn.v.bias')
+        self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.o.weight', f'blocks.{self.block_index}.self_attn.o.bias')
+        self.self_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_q.weight']
+        self.self_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.self_attn.norm_k.weight']
+        self.norm3 = LN_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.norm3.weight', f'blocks.{self.block_index}.norm3.bias', eps=1e-6)
+        self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.q.weight', f'blocks.{self.block_index}.cross_attn.q.bias')
+        self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k.weight', f'blocks.{self.block_index}.cross_attn.k.bias')
+        self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v.weight', f'blocks.{self.block_index}.cross_attn.v.bias')
+        self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.o.weight', f'blocks.{self.block_index}.cross_attn.o.bias')
+        self.cross_attn_norm_q_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_q.weight']
+        self.cross_attn_norm_k_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k.weight']
+        self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.0.weight', f'blocks.{self.block_index}.ffn.0.bias')
+        self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.2.weight', f'blocks.{self.block_index}.ffn.2.bias')
+        self.modulation = weight_dict[f'blocks.{self.block_index}.modulation']
+
+        self.weight_list = [
+            self.self_attn_q,
+            self.self_attn_k,
+            self.self_attn_v,
+            self.self_attn_o,
+            self.self_attn_norm_q_weight,
+            self.self_attn_norm_k_weight,
+            self.norm3,
+            self.cross_attn_q,
+            self.cross_attn_k,
+            self.cross_attn_v,
+            self.cross_attn_o,
+            self.cross_attn_norm_q_weight,
+            self.cross_attn_norm_k_weight,
+            self.ffn_0,
+            self.ffn_2,
+            self.modulation,
+        ]
+
+        if self.task == 'i2v':
+            self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k_img.weight', f'blocks.{self.block_index}.cross_attn.k_img.bias')
+            self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v_img.weight', f'blocks.{self.block_index}.cross_attn.v_img.bias')
+            self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
+            self.weight_list.append(self.cross_attn_k_img)
+            self.weight_list.append(self.cross_attn_v_img)
+            self.weight_list.append(self.cross_attn_norm_k_img_weight)
+
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
+                mm_weight.set_config(self.config['mm_config'])
+                mm_weight.load(weight_dict)
+
+    def to_cpu(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
+                mm_weight.to_cpu()
+            else:
+                mm_weight.cpu()
+
+    def to_cuda(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
+                mm_weight.to_cuda()
+            else:
+                mm_weight.cuda()
\ No newline at end of file
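Each block's weight_list mixes wrapper objects, moved via to_cpu()/to_cuda(), with raw tensors such as the norm weights and modulation, moved via .cpu()/.cuda(). One detail worth noting in a sketch: Tensor.cpu() returns a new tensor rather than moving in place, so raw entries need rebinding for the move to stick. Class names below are illustrative; requires a CUDA device:

import torch

class Wrapper:                       # stands in for MMWeightTemplate / LNWeightTemplate
    def __init__(self):
        self.w = torch.randn(4, 4).cuda()
    def to_cpu(self):
        self.w = self.w.cpu()        # wrappers rebind their own storage

weight_list = [Wrapper(), torch.randn(4).cuda()]

moved = []
for w in weight_list:
    if isinstance(w, Wrapper):
        w.to_cpu()
        moved.append(w)
    else:
        moved.append(w.cpu())        # Tensor.cpu() returns a new tensor, so keep the result
weight_list = moved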
@@ -566,12 +566,25 @@ class T5EncoderModel:
             name=tokenizer_path, seq_len=text_len, clean="whitespace"
         )

+    def to_cpu(self):
+        self.model = self.model.to("cpu")
+
+    def to_cuda(self):
+        self.model = self.model.to("cuda")
+
     def infer(self, texts, args):
+        if args.cpu_offload:
+            self.to_cuda()
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
         ids = ids.cuda()
         mask = mask.cuda()
         seq_lens = mask.gt(0).sum(dim=1).long()
         context = self.model(ids, mask)
+        if args.cpu_offload:
+            self.to_cpu()
         return [u[:v] for u, v in zip(context, seq_lens)]
...
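The infer() change above is the whole offload contract for the text encoder: page weights in, run, page them out. A self-contained miniature of the same flow; TinyEncoder and the 4-dim shapes are illustrative, and a CUDA device is required:

import types
import torch

class TinyEncoder:
    def __init__(self):
        self.model = torch.nn.Linear(4, 4)        # starts on the CPU
    def to_cpu(self):
        self.model = self.model.to("cpu")
    def to_cuda(self):
        self.model = self.model.to("cuda")
    def infer(self, x, args):
        if args.cpu_offload:
            self.to_cuda()                        # page weights in for this call
        out = self.model(x.cuda())
        if args.cpu_offload:
            self.to_cpu()                         # page them back out afterwards
        return out

enc = TinyEncoder()
out = enc.infer(torch.randn(2, 4), types.SimpleNamespace(cpu_offload=True))
assert not next(enc.model.parameters()).is_cuda   # weights are back on the CPU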
@@ -750,8 +750,8 @@ class WanVAE:
             1.9160,
         ]
         self.mean = torch.tensor(mean, dtype=dtype, device=device)
-        self.std = torch.tensor(std, dtype=dtype, device=device)
-        self.scale = [self.mean, 1.0 / self.std]
+        self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
+        self.scale = [self.mean, self.inv_std]

         # init model
         self.model = (
@@ -764,6 +764,18 @@ class WanVAE:
             .to(device)
         )

+    def to_cpu(self):
+        self.model = self.model.to("cpu")
+        self.mean = self.mean.cpu()
+        self.inv_std = self.inv_std.cpu()
+        self.scale = [self.mean, self.inv_std]
+
+    def to_cuda(self):
+        self.model = self.model.to("cuda")
+        self.mean = self.mean.cuda()
+        self.inv_std = self.inv_std.cuda()
+        self.scale = [self.mean, self.inv_std]
+
     def encode(self, videos):
         """
         videos: A list of videos each with shape [C, T, H, W].
@@ -823,6 +835,9 @@ class WanVAE:

     def decode(self, zs, generator, args):
+        if args.cpu_offload:
+            self.to_cuda()
+
         if self.parallel:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
@@ -838,4 +853,9 @@ class WanVAE:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+
+        if args.cpu_offload:
+            images = images.cpu().float()
+            self.to_cpu()
+
         return images
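The last commit bullet ("scale still had to be reassigned after mean and inv_std were moved") is a reference-aliasing issue: self.scale is a Python list holding the two tensors, and moving a tensor produces a new object rather than mutating the old one. A minimal reproduction (requires CUDA):

import torch

mean = torch.zeros(4).cuda()
inv_std = torch.ones(4).cuda()
scale = [mean, inv_std]

mean, inv_std = mean.cpu(), inv_std.cpu()
assert scale[0].is_cuda            # the list still references the CUDA copies

scale = [mean, inv_std]            # hence the rebuild in to_cpu()/to_cuda()
assert not scale[0].is_cuda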
@@ -48,18 +48,18 @@ def load_models(args, model_config):
     text_encoder = T5EncoderModel(
         text_len=model_config["text_len"],
         dtype=torch.bfloat16,
-        device=torch.device("cuda"),
+        device=init_device,
         checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
         tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
         shard_fn=None,
     )
     text_encoders = [text_encoder]
     model = WanModel(args.model_path, model_config)
-    vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=torch.device("cuda"), parallel=args.parallel_vae)
+    vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
     if args.task == 'i2v':
         image_encoder = CLIPModel(
             dtype=torch.float16,
-            device=torch.device("cuda"),
+            device=init_device,
             checkpoint_path=os.path.join(args.model_path,
                                          "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
             tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"))
...
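init_device is not defined in the visible hunk; presumably it is derived earlier in load_models from the offload flag, along these lines (an assumption, not shown in this diff):

import torch

cpu_offload = True   # stand-in for args.cpu_offload
init_device = torch.device("cpu") if cpu_offload else torch.device("cuda")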