"tests/vscode:/vscode.git/clone" did not exist on "203e3231db0398339de2a3409124fc4a7ed51853"
Commit a92ea6e8 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent 8ba6e3b4
......@@ -38,9 +38,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias_name is not None else None
def apply(self, input_tensor):
# if input_tensor.dtype == torch.float:
# input_tensor = input_tensor.to(torch.bfloat16)
input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
input_tensor = torch.nn.functional.conv3d(
input_tensor,
weight=self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return input_tensor
def to_cpu(self, non_blocking=False):
......
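For context, a minimal sketch (not the repository's class) of the pinned-buffer pattern these weight templates rely on: a page-locked CPU copy of each tensor lets offload and restore copies run with non_blocking=True. Shapes are made up and a CUDA build is assumed.

```python
import torch

# Sketch of the pinned-buffer offload pattern used by Conv3dWeight/MMWeight:
# keep a page-locked (pinned) CPU copy of each tensor so GPU<->CPU transfers
# can run with non_blocking=True. Assumes a CUDA device is available,
# since pin_memory requires one; shapes are illustrative.
weight = torch.randn(8, 4, 3, 3, 3, dtype=torch.bfloat16, device="cuda")
pinned_weight = torch.empty(weight.shape, dtype=weight.dtype, pin_memory=True)

pinned_weight.copy_(weight, non_blocking=True)         # offload without a hard sync
weight = pinned_weight.to("cuda", non_blocking=True)   # later: restore to the GPU
```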
......@@ -79,7 +79,6 @@ class MMWeight(MMWeightTemplate):
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
def apply(self, input_tensor):
# if input_tensor.dtype != torch.float
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
......@@ -143,6 +142,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.load_func(weight_dict)
if self.weight_need_transpose:
self.weight = self.weight.t()
self.pinned_weight = self.pinned_weight.t()
def clear(self):
attrs = ["weight", "weight_scale", "bias"]
......
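A small sketch of why the pinned host buffer must be transposed together with the weight (shapes are assumed, not taken from the checkpoint): if only the weight were transposed, a later copy between the two buffers would fail on a shape mismatch.

```python
import torch

# Transposing only `weight` would leave `pinned_weight` with the old shape,
# so a later pinned_weight.copy_(weight) during offload would raise a
# shape-mismatch error. Assumes a CUDA-capable build for pin_memory.
weight = torch.randn(4096, 1024, dtype=torch.bfloat16)
pinned_weight = torch.empty(weight.shape, dtype=weight.dtype, pin_memory=True)

weight = weight.t()                # view with shape (1024, 4096)
pinned_weight = pinned_weight.t()  # keep the host buffer's view in sync

pinned_weight.copy_(weight)        # shapes agree again: (1024, 4096)
```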
......@@ -72,10 +72,6 @@ def apply_rotary_emb(x, freqs_i):
# Apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[seq_len:]])
# if GET_DTYPE() == "BF16":
# x_i = x_i.to(torch.bfloat16)
# else:
# x_i = x_i.float()
return x_i.to(torch.bfloat16)
......
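A toy, self-contained sketch of the view_as_complex / view_as_real rotary pattern used above, ending with the unconditional bfloat16 cast from the diff; the shapes and frequency schedule are illustrative assumptions, not the repository's.

```python
import torch

# Rotary embedding as a complex multiply: pair up the last dimension,
# rotate each pair by a position-dependent angle, then flatten back.
seq_len, n_heads, head_dim = 16, 8, 64

x = torch.randn(seq_len, n_heads, head_dim)
x_c = torch.view_as_complex(x.float().reshape(seq_len, n_heads, head_dim // 2, 2))

# Hypothetical frequency schedule: one complex rotation per (position, dim pair).
pos = torch.arange(seq_len).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(pos, inv_freq)                    # (seq_len, head_dim // 2)
freqs_c = torch.polar(torch.ones_like(freqs), freqs)  # complex e^{i * theta}
freqs_c = freqs_c[:, None, :]                         # broadcast over heads

x_rot = torch.view_as_real(x_c * freqs_c).flatten(2)  # back to (seq, heads, dim)
x_rot = x_rot.to(torch.bfloat16)                      # cast once, as in the diff
```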
......@@ -37,7 +37,6 @@ class WanModel:
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
assert GET_DTYPE() == "BF16"
self.device = device
self._init_infer_class()
......@@ -63,13 +62,11 @@ class WanModel:
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_safetensor_to_dict(self, file_path):
def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
with safe_open(file_path, framework="pt") as f:
use_bf16 = GET_DTYPE() == "BF16"
skip_bf16 = {"norm", "embedding", "modulation", "time"}
return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
def _load_ckpt(self):
def _load_ckpt(self, use_bf16, skip_bf16):
safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
......@@ -77,41 +74,41 @@ class WanModel:
raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
weight_dict.update(file_weights)
return weight_dict
def _load_quant_ckpt(self):
def _load_quant_ckpt(self, use_bf16, skip_bf16):
ckpt_path = self.config.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
if ckpt_path.endswith(".pth"):
logger.info(f"Loading {ckpt_path} as PyTorch model.")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
else:
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f" Using safetensors index: {index_path}")
with open(index_path, "r") as f:
index_data = json.load(f)
weight_dict = {}
for filename in set(index_data["weight_map"].values()):
safetensor_path = os.path.join(ckpt_path, filename)
with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
weight_dict[k] = f.get_tensor(k).pin_memory()
if weight_dict[k].dtype == torch.float:
weight_dict[k] = weight_dict[k].pin_memory().to(torch.bfloat16)
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f" Using safetensors index: {index_path}")
with open(index_path, "r") as f:
index_data = json.load(f)
weight_dict = {}
for filename in set(index_data["weight_map"].values()):
safetensor_path = os.path.join(ckpt_path, filename)
with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
else:
weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
return weight_dict
def _load_quant_split_ckpt(self):
def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
lazy_load_model_path = self.config.dit_quantized_ckpt
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
pre_post_weight_dict, transformer_weight_dict = {}, {}
......@@ -119,9 +116,13 @@ class WanModel:
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for k in f.keys():
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
if pre_post_weight_dict[k].dtype == torch.float:
pre_post_weight_dict[k] = pre_post_weight_dict[k].pin_memory().to(torch.bfloat16)
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
......@@ -132,24 +133,30 @@ class WanModel:
with safe_open(file_path, framework="pt") as f:
for k in f.keys():
if "modulation" in k:
transformer_weight_dict[k] = f.get_tensor(k).pin_memory()
if transformer_weight_dict[k].dtype == torch.float:
transformer_weight_dict[k] = transformer_weight_dict[k].pin_memory().to(torch.bfloat16)
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
else:
transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
return pre_post_weight_dict, transformer_weight_dict
def _init_weights(self, weight_dict=None):
use_bf16 = GET_DTYPE() == "BF16"
# Some layers run with float32 to achieve high accuracy
skip_bf16 = {"norm", "embedding", "modulation", "time"}
if weight_dict is None:
if not self.dit_quantized or self.weight_auto_quant:
self.original_weight_dict = self._load_ckpt()
self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
else:
if not self.config.get("lazy_load", False):
self.original_weight_dict = self._load_quant_ckpt()
self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
else:
(
self.original_weight_dict,
self.transformer_weight_dict,
) = self._load_quant_split_ckpt()
) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
else:
self.original_weight_dict = weight_dict
# init weights
......
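The selective casting rule shared by these loaders can be summarized in a small standalone helper. This is a sketch under the assumption that it mirrors the quantized loaders above: float32 tensors are cast to bfloat16 unless BF16 is disabled globally and the key matches one of the keep-in-float32 patterns. The key names in the checks are illustrative, not real checkpoint keys.

```python
import torch

def maybe_cast_bf16(name, tensor, use_bf16, skip_bf16):
    """Sketch of the casting rule: cast float32 tensors to bfloat16 unless
    BF16 is off globally and the name contains a keep-in-float32 pattern
    (norm / embedding / modulation / time run in float32 for accuracy)."""
    if tensor.dtype != torch.float:
        return tensor
    if use_bf16 or all(s not in name for s in skip_bf16):
        return tensor.to(torch.bfloat16)
    return tensor

# Toy check of the rule with made-up parameter names.
skip_bf16 = {"norm", "embedding", "modulation", "time"}
t = torch.randn(4, dtype=torch.float32)
assert maybe_cast_bf16("blocks.0.ffn.0.weight", t, False, skip_bf16).dtype == torch.bfloat16
assert maybe_cast_bf16("blocks.0.norm3.weight", t, False, skip_bf16).dtype == torch.float32
assert maybe_cast_bf16("blocks.0.norm3.weight", t, True, skip_bf16).dtype == torch.bfloat16
```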
......@@ -519,7 +519,7 @@ def convert_weights(args):
def copy_non_weight_files(source_dir, target_dir):
ignore_extensions = [".pth", ".pt", ".safetensors"]
ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"]
logger.info(f"Start copying non-weighted files and subdirectories...")
......@@ -586,47 +586,48 @@ def main():
)
args = parser.parse_args()
if args.dtype == "torch.int8":
args.dtype = torch.int8
elif args.dtype == "torch.float8_e4m3fn":
args.dtype = torch.float8_e4m3fn
else:
raise ValueError(f"Not support dtype :{args.dtype}")
model_type_keys_map = {
"wan_dit": {
"key_idx": 2,
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": None,
},
"hunyuan_dit": {
"key_idx": 2,
"target_keys": [
"img_mod",
"img_attn_qkv",
"img_attn_proj",
"img_mlp",
"txt_mod",
"txt_attn_qkv",
"txt_attn_proj",
"txt_mlp",
"linear1",
"linear2",
"modulation",
],
"ignore_key": None,
},
"wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
"wan_clip": {
"key_idx": 3,
"target_keys": ["attn", "mlp"],
"ignore_key": "textual",
},
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
if args.quantized:
if args.dtype == "torch.int8":
args.dtype = torch.int8
elif args.dtype == "torch.float8_e4m3fn":
args.dtype = torch.float8_e4m3fn
else:
raise ValueError(f"Not support dtype :{args.dtype}")
model_type_keys_map = {
"wan_dit": {
"key_idx": 2,
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": None,
},
"hunyuan_dit": {
"key_idx": 2,
"target_keys": [
"img_mod",
"img_attn_qkv",
"img_attn_proj",
"img_mlp",
"txt_mod",
"txt_attn_qkv",
"txt_attn_proj",
"txt_mlp",
"linear1",
"linear2",
"modulation",
],
"ignore_key": None,
},
"wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
"wan_clip": {
"key_idx": 3,
"target_keys": ["attn", "mlp"],
"ignore_key": "textual",
},
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
......
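A rough sketch of the resulting control flow in main(): the dtype string is mapped to a torch dtype and the per-model key map is consulted only when --quantized is set. Flag names come from the diff; the defaults and error message wording are illustrative, and torch.float8_e4m3fn requires a recent PyTorch.

```python
import argparse
import torch

# Quantization-specific arguments are only resolved when --quantized is passed.
parser = argparse.ArgumentParser()
parser.add_argument("--quantized", action="store_true")
parser.add_argument("--dtype", default="torch.int8")
parser.add_argument("--model_type", default="wan_dit")
args = parser.parse_args(["--quantized", "--dtype", "torch.float8_e4m3fn"])

if args.quantized:
    dtype_map = {"torch.int8": torch.int8, "torch.float8_e4m3fn": torch.float8_e4m3fn}
    if args.dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype: {args.dtype}")
    args.dtype = dtype_map[args.dtype]

print(args.dtype)  # torch.float8_e4m3fn
```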
......@@ -33,7 +33,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
......@@ -44,7 +44,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
......@@ -57,7 +57,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
......@@ -68,7 +68,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
......
......@@ -33,7 +33,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
......@@ -44,7 +44,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
......@@ -57,7 +57,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
......@@ -68,7 +68,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
......