Commit a92ea6e8 authored by gushiqiao

Fix

parent 8ba6e3b4
@@ -38,9 +38,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias_name is not None else None
 
     def apply(self, input_tensor):
-        # if input_tensor.dtype == torch.float:
-        #     input_tensor = input_tensor.to(torch.bfloat16)
-        input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+        input_tensor = torch.nn.functional.conv3d(
+            input_tensor,
+            weight=self.weight,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
         return input_tensor
 
     def to_cpu(self, non_blocking=False):
...
@@ -79,7 +79,6 @@ class MMWeight(MMWeightTemplate):
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
 
     def apply(self, input_tensor):
-        # if input_tensor.dtype != torch.float
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
@@ -143,6 +142,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.load_func(weight_dict)
         if self.weight_need_transpose:
             self.weight = self.weight.t()
+            self.pinned_weight = self.pinned_weight.t()
 
     def clear(self):
         attrs = ["weight", "weight_scale", "bias"]
...
@@ -72,10 +72,6 @@ def apply_rotary_emb(x, freqs_i):
 
     # Apply rotary embedding
     x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
     x_i = torch.cat([x_i, x[seq_len:]])
-    # if GET_DTYPE() == "BF16":
-    #     x_i = x_i.to(torch.bfloat16)
-    # else:
-    #     x_i = x_i.float()
 
     return x_i.to(torch.bfloat16)
...
@@ -37,7 +37,6 @@ class WanModel:
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
-            assert GET_DTYPE() == "BF16"
 
         self.device = device
         self._init_infer_class()
@@ -63,13 +62,11 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
 
-    def _load_safetensor_to_dict(self, file_path):
+    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
         with safe_open(file_path, framework="pt") as f:
-            use_bf16 = GET_DTYPE() == "BF16"
-            skip_bf16 = {"norm", "embedding", "modulation", "time"}
             return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
 
-    def _load_ckpt(self):
+    def _load_ckpt(self, use_bf16, skip_bf16):
         safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -77,21 +74,17 @@ class WanModel:
             raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
 
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path)
+            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
             weight_dict.update(file_weights)
         return weight_dict
 
-    def _load_quant_ckpt(self):
+    def _load_quant_ckpt(self, use_bf16, skip_bf16):
         ckpt_path = self.config.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")
 
-        if ckpt_path.endswith(".pth"):
-            logger.info(f"Loading {ckpt_path} as PyTorch model.")
-            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
-        else:
-            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
-            if not index_files:
-                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
-            index_path = os.path.join(ckpt_path, index_files[0])
-            logger.info(f" Using safetensors index: {index_path}")
+        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
+        if not index_files:
+            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
+        index_path = os.path.join(ckpt_path, index_files[0])
+        logger.info(f" Using safetensors index: {index_path}")
@@ -102,16 +95,20 @@ class WanModel:
         weight_dict = {}
         for filename in set(index_data["weight_map"].values()):
             safetensor_path = os.path.join(ckpt_path, filename)
-            with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
+            with safe_open(safetensor_path, framework="pt") as f:
                 logger.info(f"Loading weights from {safetensor_path}")
                 for k in f.keys():
-                    weight_dict[k] = f.get_tensor(k).pin_memory()
-                    if weight_dict[k].dtype == torch.float:
-                        weight_dict[k] = weight_dict[k].pin_memory().to(torch.bfloat16)
+                    if f.get_tensor(k).dtype == torch.float:
+                        if use_bf16 or all(s not in k for s in skip_bf16):
+                            weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                        else:
+                            weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                    else:
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
 
         return weight_dict
 
-    def _load_quant_split_ckpt(self):
+    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
         lazy_load_model_path = self.config.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
         pre_post_weight_dict, transformer_weight_dict = {}, {}
@@ -119,9 +116,13 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
-                pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
-                if pre_post_weight_dict[k].dtype == torch.float:
-                    pre_post_weight_dict[k] = pre_post_weight_dict[k].pin_memory().to(torch.bfloat16)
+                if f.get_tensor(k).dtype == torch.float:
+                    if use_bf16 or all(s not in k for s in skip_bf16):
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    else:
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                else:
+                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
 
         safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -132,24 +133,30 @@ class WanModel:
             with safe_open(file_path, framework="pt") as f:
                 for k in f.keys():
                     if "modulation" in k:
-                        transformer_weight_dict[k] = f.get_tensor(k).pin_memory()
-                        if transformer_weight_dict[k].dtype == torch.float:
-                            transformer_weight_dict[k] = transformer_weight_dict[k].pin_memory().to(torch.bfloat16)
+                        if f.get_tensor(k).dtype == torch.float:
+                            if use_bf16 or all(s not in k for s in skip_bf16):
+                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                            else:
+                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
 
         return pre_post_weight_dict, transformer_weight_dict
 
     def _init_weights(self, weight_dict=None):
+        use_bf16 = GET_DTYPE() == "BF16"
+        # Some layers run with float32 to achieve high accuracy
+        skip_bf16 = {"norm", "embedding", "modulation", "time"}
+
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt()
+                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
             else:
                 if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt()
+                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
                 else:
                     (
                         self.original_weight_dict,
                         self.transformer_weight_dict,
-                    ) = self._load_quant_split_ckpt()
+                    ) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
         else:
             self.original_weight_dict = weight_dict
         # init weights
...
@@ -519,7 +519,7 @@ def convert_weights(args):
 
 
 def copy_non_weight_files(source_dir, target_dir):
-    ignore_extensions = [".pth", ".pt", ".safetensors"]
+    ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"]
 
     logger.info(f"Start copying non-weighted files and subdirectories...")
@@ -586,6 +586,7 @@ def main():
     )
 
     args = parser.parse_args()
+    if args.quantized:
     if args.dtype == "torch.int8":
         args.dtype = torch.int8
     elif args.dtype == "torch.float8_e4m3fn":
...
@@ -33,7 +33,7 @@ python converter.py \
     --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name wan_int8 \
     --dtype torch.int8 \
     --model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
     --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name wan_fp8 \
     --dtype torch.float8_e4m3fn \
     --model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
     --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name hunyuan_int8 \
     --dtype torch.int8 \
     --model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
     --quantized \
    --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name hunyuan_fp8 \
     --dtype torch.float8_e4m3fn \
     --model_type hunyuan_dit
...
@@ -33,7 +33,7 @@ python converter.py \
     --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name wan_int8 \
     --dtype torch.int8 \
     --model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
     --quantized \
     --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name wan_fp8 \
     --dtype torch.float8_e4m3fn \
     --model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
     --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name hunyuan_int8 \
     --dtype torch.int8 \
     --model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
     --quantized \
     --source /Path/To/hunyuan/lightx2v_format/i2v/ \
     --output /Path/To/output \
-    --output_ext .pth\
+    --output_ext .safetensors \
     --output_name hunyuan_fp8 \
     --dtype torch.float8_e4m3fn \
     --model_type hunyuan_dit
...