"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f5a1343f97bf315b94a4f8a00663bd0f106155cf"
Commit 683aaa3a authored by gushiqiao, committed by Yang Yong(雍洋)

Support sync cpu offload. (#10)


Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent d76fc3db
File mode changed from 100644 to 100755
@@ -54,7 +54,7 @@ def load_models(args, model_config):
         shard_fn=None,
     )
     text_encoders = [text_encoder]
-    model = WanModel(args.model_path, model_config)
+    model = WanModel(args.model_path, model_config, device=init_device)
     vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
     if args.task == "i2v":
         image_encoder = CLIPModel(
@@ -97,8 +97,7 @@ def run_image_encoder(args, image_encoder, vae_model):
     elif args.model_cls == "wan2.1":
         img = Image.open(args.image_path).convert("RGB")
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = image_encoder.visual([img[:, None, :, :]]).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], args).squeeze(0).to(torch.bfloat16)
         h, w = img.shape[1:]
         aspect_ratio = h / w
         max_area = args.target_height * args.target_width
@@ -115,13 +114,14 @@ def run_image_encoder(args, image_encoder, vae_model):
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        vae_encode_out = vae_model.encode([torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()])[0]
+        vae_encode_out = vae_model.encode(
+            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], args
+        )[0]
         vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
         return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
     else:
-        raise NotImplementedError(f"Unsupported model class: {model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")


 def run_text_encoder(args, text, text_encoders, model_config):
@@ -279,15 +279,10 @@ if __name__ == "__main__":
     gc.collect()
     torch.cuda.empty_cache()
-    if args.cpu_offload:
-        model.to_cuda()
     latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
-    if args.cpu_offload:
-        model.to_cpu()
-        gc.collect()
-        torch.cuda.empty_cache()
+    gc.collect()
+    torch.cuda.empty_cache()
     images = run_vae(latents, generator, args)
...
import torch


class WeightStreamManager(object):
    def __init__(self):
        self.active_weights = [None for _ in range(2)]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.load_stream = torch.cuda.Stream(priority=0)

    def prefetch_weights(self, block_idx, blocks_weights):
        with torch.cuda.stream(self.load_stream):
            if self.active_weights[1] is not None:
                self.active_weights[1].to_cpu_sync()
            new_weights = blocks_weights[block_idx]
            new_weights.to_cuda_sync()
            self.active_weights[1] = new_weights

    def swap_weights(self):
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[1],
            self.active_weights[0],
        )
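For orientation, a minimal sketch of the ping-pong pattern this manager implements: slot 0 holds the weights being computed with, slot 1 is filled on the load stream, and `swap_weights` fences both streams before exchanging the slots. `DummyBlockWeights` is a hypothetical stand-in for the repo's per-block weight objects, which expose `to_cuda`/`to_cuda_sync`/`to_cpu_sync` (see `WanTransformerAttentionBlock` further down).

```python
import torch

class DummyBlockWeights:
    """Hypothetical stand-in for a block's weight container."""

    def __init__(self):
        # Pinned host memory is what makes non_blocking copies truly async.
        self.w = torch.randn(256, 256, pin_memory=True)

    def to_cuda(self):
        self.w = self.w.cuda()

    def to_cuda_sync(self):
        self.w = self.w.cuda(non_blocking=True)

    def to_cpu_sync(self):
        self.w = self.w.to("cpu", non_blocking=True)

blocks = [DummyBlockWeights() for _ in range(4)]
mgr = WeightStreamManager()
mgr.active_weights[0] = blocks[0]
mgr.active_weights[0].to_cuda()

x = torch.randn(256, 256, device="cuda")
for i in range(len(blocks)):
    with torch.cuda.stream(mgr.compute_stream):
        x = x @ mgr.active_weights[0].w  # stands in for infer_block
    if i < len(blocks) - 1:
        mgr.prefetch_weights(i + 1, blocks)
    mgr.swap_weights()
```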
import torch
import torch.nn as nn


class MemoryEfficientBlocks(nn.Module):
    def __init__(self, block_class, num_blocks, **block_params):
        super().__init__()
        self.block_class = block_class
        self.num_blocks = num_blocks
        self.block_params = block_params
        # Initialize two blocks: one computes while the other receives weights.
        self.active_blocks = nn.ModuleList([block_class(**block_params) for _ in range(2)])
        # Dedicated CUDA streams: compute runs at high priority, weight
        # loading at normal priority.
        self.compute_stream = torch.cuda.Stream(priority=-1)  # high priority
        self.load_stream = torch.cuda.Stream(priority=0)  # normal priority
        # Free cached GPU memory and cap per-process usage. (Note: this does
        # not pre-allocate pinned memory; empty_cache() returns None.)
        torch.cuda.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(0.8)  # limit GPU memory usage
        # CPU-side buffer holding the weights of every block.
        self.weight_buffer = []

    def initialize_weights(self, checkpoint, key):
        """Load all block weights into CPU memory."""
        for i in range(self.num_blocks):
            block_weights = {k.replace(f"{key}.{i}.", ""): v for k, v in checkpoint.items() if f"{key}.{i}." in k}
            self.weight_buffer.append(block_weights)

    def prefetch_weights(self, block_idx):
        """Preload the next block's weights on the dedicated CUDA stream."""
        with torch.cuda.stream(self.load_stream):
            next_weights = self.weight_buffer[block_idx]
            next_weights = {k: v.cuda(non_blocking=True) for k, v in next_weights.items()}
            self.active_blocks[1].load_state_dict(next_weights)

    def swap_blocks(self):
        """Swap the two blocks once compute and load have both finished."""
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        self.active_blocks[0], self.active_blocks[1] = self.active_blocks[1], self.active_blocks[0]

    def forward(self, *args, **kwargs):
        """Run all blocks in sequence, overlapping compute with weight loading."""
        for i in range(self.num_blocks):
            if i == 0:
                self.active_blocks[0].load_state_dict(self.weight_buffer[0])
            # Compute the current block on the main compute stream.
            with torch.cuda.stream(self.compute_stream):
                current_block = self.active_blocks[0]
                outputs = current_block(*args, **kwargs)
            # Preload the next block's weights on the load stream.
            if i < self.num_blocks - 1:
                self.prefetch_weights(i + 1)
            # Swap blocks so the freshly loaded one computes next.
            self.swap_blocks()
            # Feed the current outputs back in as the next block's inputs.
            # (Use j for the inner index so the outer loop's i is not clobbered.)
            args = list(args)
            if len(outputs) == 1:
                args[0] = outputs[0]
            else:
                for j in range(len(outputs)):
                    args[j] = outputs[j]
            args = tuple(args)
        return outputs
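A hypothetical usage sketch of `MemoryEfficientBlocks`, assuming a toy block whose state-dict keys live under a `blocks.{i}.` prefix; `ToyBlock` and the key layout are illustrative only, not from the repo.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Illustrative block; returns a 1-tuple, as the wrapper expects."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return (self.proj(x),)

# Flat checkpoint with per-block key prefixes, as initialize_weights expects.
ckpt = {}
for i in range(4):
    ckpt[f"blocks.{i}.proj.weight"] = torch.randn(64, 64)
    ckpt[f"blocks.{i}.proj.bias"] = torch.zeros(64)

mod = MemoryEfficientBlocks(ToyBlock, num_blocks=4, dim=64).cuda()
mod.initialize_weights(ckpt, key="blocks")
out = mod(torch.randn(2, 64, device="cuda"))  # out is a 1-tuple
```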
@@ -28,6 +28,16 @@ class MMWeightTemplate(metaclass=ABCMeta):
         if config is not None:
             self.config = config

+    def to_cpu(self, non_blocking=False):
+        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+        if self.bias is not None:
+            self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=False):
+        self.weight = self.weight.cuda(non_blocking=non_blocking)
+        if self.bias is not None:
+            self.bias = self.bias.cuda(non_blocking=non_blocking)
+

 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
@@ -47,16 +57,6 @@ class MMWeight(MMWeightTemplate):
             return torch.mm(input_tensor, self.weight, out=output_tensor)
         return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

-    def to_cpu(self):
-        self.weight = self.weight.cpu()
-        if self.bias is not None:
-            self.bias = self.bias.cpu()
-
-    def to_cuda(self):
-        self.weight = self.weight.cuda()
-        if self.bias is not None:
-            self.bias = self.bias.cuda()
-

 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
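The commit hoists `to_cpu`/`to_cuda` into `MMWeightTemplate` and gives them a `non_blocking` flag, which the `*_sync` block methods further down pass as `True`. Worth keeping in mind (a general CUDA fact, not something this diff enforces): a `non_blocking` host-to-device copy only overlaps with compute when the source tensor is in pinned memory; from pageable memory it silently degrades to a synchronous copy. A small standalone sketch:

```python
import torch

pageable = torch.randn(4096, 4096)
pinned = torch.randn(4096, 4096).pin_memory()

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    a = pageable.cuda(non_blocking=True)  # degrades to a blocking copy
    b = pinned.cuda(non_blocking=True)    # true async copy on `stream`
stream.synchronize()
```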
@@ -87,7 +87,7 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
     def load(self, weight_dict):
         if self.config.get("weight_auto_quant", True):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "channel")
+            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
             self.weight_scale = self.weight_scale.to(torch.float32).cuda()
@@ -105,18 +105,6 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
         torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
         return output_tensor

-    def to_cpu(self):
-        self.weight = self.weight.cpu()
-        self.weight_scale = self.weight_scale.cpu()
-        if self.bias is not None:
-            self.bias = self.bias.cpu()
-
-    def to_cuda(self):
-        self.weight = self.weight.cuda()
-        self.weight_scale = self.weight_scale.cuda()
-        if self.bias is not None:
-            self.bias = self.bias.cuda()
-

 @MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
 class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
@@ -135,7 +123,7 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
     def load(self, weight_dict):
         if self.config.get("weight_auto_quant", True):
             self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = IntegerQuantizer(8, True, "channel")
+            w_quantizer = IntegerQuantizer(8, True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8).t().cuda()
             self.weight_scale = self.weight_scale.to(torch.float32).cuda()
@@ -153,18 +141,6 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
         torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
         return output_tensor

-    def to_cpu(self):
-        self.weight = self.weight.cpu()
-        self.weight_scale = self.weight_scale.cpu()
-        if self.bias is not None:
-            self.bias = self.bias.cpu()
-
-    def to_cuda(self):
-        self.weight = self.weight.cuda()
-        self.weight_scale = self.weight_scale.cuda()
-        if self.bias is not None:
-            self.bias = self.bias.cuda()
-

 @MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
 class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
@@ -183,7 +159,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
     def load(self, weight_dict):
         if self.config.get("weight_auto_quant", True):
             self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = IntegerQuantizer(8, True, "channel")
+            w_quantizer = IntegerQuantizer(8, True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
@@ -197,18 +173,6 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
         output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
         return output_tensor.squeeze(0)

-    def to_cpu(self):
-        self.weight = self.weight.cpu()
-        self.weight_scale = self.weight_scale.cpu()
-        if self.bias is not None:
-            self.bias = self.bias.cpu()
-
-    def to_cuda(self):
-        self.weight = self.weight.cuda()
-        self.weight_scale = self.weight_scale.cuda()
-        if self.bias is not None:
-            self.bias = self.bias.cuda()
-

 @MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
 class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
@@ -227,7 +191,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
     def load(self, weight_dict):
         if self.config.get("weight_auto_quant", True):
             self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "channel")
+            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
@@ -241,18 +205,6 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
         output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
         return output_tensor.squeeze(0)

-    def to_cpu(self):
-        self.weight = self.weight.cpu()
-        self.weight_scale = self.weight_scale.cpu()
-        if self.bias is not None:
-            self.bias = self.bias.cpu()
-
-    def to_cuda(self):
-        self.weight = self.weight.cuda()
-        self.weight_scale = self.weight_scale.cuda()
-        if self.bias is not None:
-            self.bias = self.bias.cuda()
-

 if __name__ == "__main__":
     weight_dict = {
...
@@ -25,8 +25,8 @@ class MMWeightCalib(MMWeight):
     def get_quantizer(self):
         if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
-            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
-            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
+            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
+            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
             self.w_quantizer = FloatQuantizer(**self.w_setting)
             self.a_quantizer = FloatQuantizer(**self.a_setting)
             self.act_dynamic_quant = True
...
@@ -442,7 +442,9 @@ class CLIPModel:
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")

-    def visual(self, videos):
+    def visual(self, videos, args):
+        if args.cpu_offload:
+            self.to_cuda()
         # preprocess
         size = (self.model.image_size,) * 2
         videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
@@ -451,4 +453,13 @@ class CLIPModel:
         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
             out = self.model.visual(videos, use_31_block=True)
-        return out
+
+        if args.cpu_offload:
+            self.to_cpu()
+        return out
+
+    def to_cuda(self):
+        self.model = self.model.cuda()
+
+    def to_cpu(self):
+        self.model = self.model.cpu()
 import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
 from lightx2v.attentions import attention
+from lightx2v.common.offload.manager import WeightStreamManager


 class WanTransformerInfer:
@@ -13,6 +14,11 @@ class WanTransformerInfer:
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        if self.config["cpu_offload"]:
+            self.weights_stream_mgr = WeightStreamManager()
+            self.infer_func = self._infer_with_offload
+        else:
+            self.infer_func = self._infer_without_offload

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -29,9 +35,36 @@ class WanTransformerInfer:
         return cu_seqlens_q, cu_seqlens_k, lq, lk

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        for i in range(self.blocks_num):
+        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+
+    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        for block_idx in range(self.blocks_num):
+            if block_idx == 0:
+                self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
+                self.weights_stream_mgr.active_weights[0].to_cuda()
+            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                x = self.infer_block(
+                    self.weights_stream_mgr.active_weights[0],
+                    grid_sizes,
+                    embed,
+                    x,
+                    embed0,
+                    seq_lens,
+                    freqs,
+                    context,
+                )
+            if block_idx < self.blocks_num - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
+            self.weights_stream_mgr.swap_weights()
+        return x
+
+    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        for block_idx in range(self.blocks_num):
             x = self.infer_block(
-                weights.blocks_weights[i],
+                weights.blocks_weights[block_idx],
                 grid_sizes,
                 embed,
                 x,
@@ -40,7 +73,6 @@ class WanTransformerInfer:
                 freqs,
                 context,
             )
-
         return x
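`_infer_with_offload` is the heart of the change: block `i` runs on the high-priority compute stream while `prefetch_weights` uploads block `i + 1` on the load stream, and `swap_weights` joins both streams before the buffers are exchanged. A standalone sketch (assumed, not repo code) of how one might verify that the two streams actually overlap, using CUDA events:

```python
import torch

compute = torch.cuda.Stream(priority=-1)
load = torch.cuda.Stream(priority=0)

x = torch.randn(8192, 8192, device="cuda")
w_cpu = torch.randn(8192, 8192).pin_memory()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
start.record()
with torch.cuda.stream(compute):
    y = x @ x  # stands in for infer_block on the resident weights
with torch.cuda.stream(load):
    w_gpu = w_cpu.cuda(non_blocking=True)  # stands in for prefetch_weights
compute.synchronize()
load.synchronize()
end.record()
torch.cuda.synchronize()
# If the copy overlapped the matmul, this is close to the matmul time alone.
print(f"step time: {start.elapsed_time(end):.2f} ms")
```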
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -101,11 +133,28 @@ class WanTransformerInfer:
         k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
         v = weights.cross_attn_v.apply(context).view(-1, n, d)

+        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
+
+        attn_out = attention(
+            attention_type=self.attention_type,
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_k,
+            max_seqlen_q=lq,
+            max_seqlen_kv=lk,
+        )
+
         if self.task == "i2v":
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
-            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device))
+            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
+                q,
+                k_img,
+                k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
+            )
             img_attn_out = attention(
                 attention_type=self.attention_type,
@@ -118,18 +167,8 @@ class WanTransformerInfer:
                 max_seqlen_kv=lk,
             )

-        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
-        attn_out = attention(
-            attention_type=self.attention_type,
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_k,
-            max_seqlen_q=lq,
-            max_seqlen_kv=lk,
-        )
+            attn_out = attn_out + img_attn_out

         attn_out = weights.cross_attn_o.apply(attn_out)
         x = x + attn_out
...
@@ -23,9 +23,10 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config):
+    def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.device = device
         self._init_infer_class()
         self._init_weights()
         self._init_infer()
@@ -53,7 +54,7 @@ class WanModel:
     def _load_safetensor_to_dict(self, file_path):
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):
@@ -102,6 +103,10 @@ class WanModel:
     def infer(self, text_encoders_output, image_encoder_output, args):
         timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

+        if self.config["cpu_offload"]:
+            self.pre_weight.to_cuda()
+            self.post_weight.to_cuda()
+
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
             self.pre_weight,
             [self.scheduler.latents],
@@ -112,6 +117,7 @@ class WanModel:
             [image_encoder_output["vae_encode_out"]],
         )
+
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

         if self.config["feature_caching"] == "Tea":
@@ -128,6 +134,7 @@ class WanModel:
             image_encoder_output["clip_encoder_out"],
             [image_encoder_output["vae_encode_out"]],
         )
+
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

@@ -137,3 +144,7 @@ class WanModel:
             self.scheduler.cnt = 0

         self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+
+        if self.config["cpu_offload"]:
+            self.pre_weight.to_cpu()
+            self.post_weight.to_cpu()
@@ -10,23 +10,24 @@ class WanPostWeights:
         self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
         self.head_modulation = weight_dict["head.modulation"]
-        self.weight_list = [self.head, self.head_modulation]
+        self.weight_list = [self.head]

         for mm_weight in self.weight_list:
             if isinstance(mm_weight, MMWeightTemplate):
                 mm_weight.set_config(self.config["mm_config"])
                 mm_weight.load(weight_dict)
+                if self.config["cpu_offload"]:
+                    mm_weight.to_cpu()
+                    self.head_modulation = self.head_modulation.cpu()

     def to_cpu(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, MMWeightTemplate):
                 mm_weight.to_cpu()
-            else:
-                mm_weight.cpu()
+        self.head_modulation = self.head_modulation.cpu()

     def to_cuda(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, MMWeightTemplate):
                 mm_weight.to_cuda()
-            else:
-                mm_weight.cuda()
+        self.head_modulation = self.head_modulation.cuda()
@@ -44,6 +44,8 @@ class WanPreWeights:
             if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
                 mm_weight.set_config(self.config["mm_config"])
                 mm_weight.load(weight_dict)
+                if self.config["cpu_offload"]:
+                    mm_weight.to_cpu()

     def to_cpu(self):
         for mm_weight in self.weight_list:
...
@@ -71,7 +71,7 @@ class WanTransformerAttentionBlock:
             self.cross_attn_norm_k,
             self.ffn_0,
             self.ffn_2,
-            self.modulation,
+            # self.modulation,
         ]

         if self.task == "i2v":
@@ -87,17 +87,30 @@ class WanTransformerAttentionBlock:
             if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.set_config(self.config["mm_config"])
                 mm_weight.load(weight_dict)
+                if self.config["cpu_offload"]:
+                    mm_weight.to_cpu()
+                    self.modulation = self.modulation.cpu()

     def to_cpu(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.to_cpu()
-            else:
-                mm_weight.cpu()
+        self.modulation = self.modulation.cpu()

     def to_cuda(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                 mm_weight.to_cuda()
-            else:
-                mm_weight.cuda()
+        self.modulation = self.modulation.cuda()
+
+    def to_cpu_sync(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
+                mm_weight.to_cpu(non_blocking=True)
+        self.modulation = self.modulation.to("cpu", non_blocking=True)
+
+    def to_cuda_sync(self):
+        for mm_weight in self.weight_list:
+            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
+                mm_weight.to_cuda(non_blocking=True)
+        self.modulation = self.modulation.cuda(non_blocking=True)
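Naming note: despite the `_sync` suffix, `to_cpu_sync`/`to_cuda_sync` issue `non_blocking=True` copies; the actual synchronization happens in `WeightStreamManager.swap_weights`. As a general CUDA caveat, a `non_blocking` device-to-host copy is only truly asynchronous into pinned memory, and the destination must not be read before the stream is synchronized. A minimal sketch of the safe pattern, assuming nothing from the repo:

```python
import torch

src = torch.randn(1024, device="cuda")
dst = torch.empty(1024, pin_memory=True)

dst.copy_(src, non_blocking=True)          # async D2H into pinned memory
torch.cuda.current_stream().synchronize()  # fence before touching dst
print(dst[:4])
```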
@@ -725,11 +725,18 @@ class WanVAE:
         self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

-    def encode(self, videos):
+    def encode(self, videos, args):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+        if args.cpu_offload:
+            self.to_cuda()
+        out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+        if args.cpu_offload:
+            self.to_cpu()
+        return out

     def decode_dist(self, zs, world_size, cur_rank, split_dim):
         splited_total_len = zs.shape[split_dim]
...
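Unlike the streamed transformer path, the encoder-side offload in `CLIPModel.visual` and `WanVAE.encode` is plain synchronous swapping: move the module to the GPU, run, move it back. The same pattern could be written once as a context manager; the helper below is hypothetical, not part of this commit.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def on_gpu(module: nn.Module, enabled: bool = True):
    """Temporarily host `module` on the GPU when `enabled` is True."""
    if enabled:
        module.cuda()
    try:
        yield module
    finally:
        if enabled:
            module.cpu()
            torch.cuda.empty_cache()

# e.g.  with on_gpu(self.model, args.cpu_offload):
#           out = [self.model.encode(u.unsqueeze(0), self.scale) for u in videos]
```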