Commit 2bb1b0f0 authored by Xinchi Huang, committed by GitHub

Xinchi/fix offload (#57)

* Fix extra offload latency in the first step by pre-allocating pinned memory

* pre-commit

---------
Co-authored-by: de1star <843414674@qq.com>
parent af248eb2
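The change below follows a standard pattern for fast CPU offload in PyTorch: allocate a page-locked (pinned) host buffer once at load time, then reuse it for every device-to-host copy. Allocating pinned memory is expensive, so paying that cost lazily inside the first offload call is what produced the extra first-step latency this commit removes. A minimal standalone sketch of the pattern (class and attribute names here are illustrative, not the actual lightx2v classes):

import torch

class OffloadedParam:
    # Illustrative-only sketch of the pre-allocated pinned-buffer offload pattern.

    def load(self, tensor):
        self.tensor = tensor.cuda()
        # Pay the pinned-memory allocation cost once, at load time,
        # instead of on the first offload step.
        self.pinned = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)

    def to_cpu(self, non_blocking=False):
        # copy_ into the reusable pinned buffer; with non_blocking=True the
        # D2H transfer is asynchronous with respect to the host.
        self.tensor = self.pinned.copy_(self.tensor, non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        # H2D copies from pinned memory are also eligible for async transfer.
        self.tensor = self.tensor.cuda(non_blocking=non_blocking)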
@@ -57,7 +57,9 @@ class MMWeight(MMWeightTemplate):
     def load(self, weight_dict):
         self.weight = weight_dict[self.weight_name].t()
+        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
         self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
@@ -76,6 +78,15 @@ class MMWeight(MMWeightTemplate):
         destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination

+    def to_cpu(self, non_blocking=False):
+        # self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+        self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+        if hasattr(self, "weight_scale"):
+            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+        if self.bias is not None:
+            # self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+            self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+
 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
 ...
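One nuance not visible in this diff: a non_blocking copy issued on the default CUDA stream is asynchronous with respect to the host but still ordered after earlier work on that stream. To actually overlap the offload copy with compute, the copy is typically issued on a separate stream and synchronized before the pinned buffer is read. A hedged sketch, independent of how lightx2v schedules its offload:

import torch

offload_stream = torch.cuda.Stream()
weight_gpu = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
pinned = torch.empty(weight_gpu.shape, dtype=weight_gpu.dtype, pin_memory=True)

# Make sure the side stream sees the fully-produced weight before copying.
offload_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(offload_stream):
    # D2H copy into the pinned buffer runs on the side stream.
    pinned.copy_(weight_gpu, non_blocking=True)

# ... compute on the default stream can proceed here ...

offload_stream.synchronize()  # ensure the copy finished before reading `pinned`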
+import torch
 from lightx2v.utils.registry_factory import TENSOR_REGISTER
@@ -8,9 +9,11 @@ class DefaultTensor:
     def load(self, weight_dict):
         self.tensor = weight_dict[self.tensor_name]
+        self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

     def to_cpu(self, non_blocking=False):
-        self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
+        # self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
+        self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()

     def to_cuda(self, non_blocking=False):
         self.tensor = self.tensor.cuda(non_blocking=non_blocking)
 ...
@@ -64,9 +64,9 @@ class WanModel:
         use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).pin_memory().to(self.device) for key in f.keys()}
             else:
-                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):
 ...
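The loader change is the complementary half: the tensors returned by safe_open are not pinned, and .pin_memory() copies each one into page-locked memory before the .to(self.device) upload, which lets the host-to-device transfer use the faster DMA path. A rough timing sketch (not part of the repo; shape and dtype are arbitrary) of the difference for the upload itself, once the one-time pinning cost has been paid:

import time
import torch

x_pageable = torch.randn(4096, 4096, dtype=torch.bfloat16)
x_pinned = x_pageable.pin_memory()   # one-time cost, paid while loading

torch.cuda.synchronize()
t0 = time.time()
_ = x_pageable.to("cuda")            # upload from pageable host memory
torch.cuda.synchronize()
t1 = time.time()
_ = x_pinned.to("cuda", non_blocking=True)  # upload from pinned host memory
torch.cuda.synchronize()
t2 = time.time()

print(f"pageable H2D: {t1 - t0:.4f}s  pinned H2D: {t2 - t1:.4f}s")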