"...resnet50_tensorflow.git" did not exist on "6c2a1d6bc0ec1d345b247a1f89f0ff975628b8a8"
Commit a92ea6e8 authored by gushiqiao

Fix

parent 8ba6e3b4
@@ -38,9 +38,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias_name is not None else None

     def apply(self, input_tensor):
-        # if input_tensor.dtype == torch.float:
-        #     input_tensor = input_tensor.to(torch.bfloat16)
-        input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+        input_tensor = torch.nn.functional.conv3d(
+            input_tensor,
+            weight=self.weight,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
         return input_tensor

     def to_cpu(self, non_blocking=False):
......
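For reference, a tiny runnable example of the functional conv3d call used above, with made-up shapes (input is [N, C_in, D, H, W], weight is [C_out, C_in/groups, kD, kH, kW]); it is only an illustration, not code from this repository:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 16, 16)   # N, C_in, D, H, W
w = torch.randn(8, 4, 3, 3, 3)     # C_out, C_in, kD, kH, kW
b = torch.zeros(8)

# Same keyword-argument style as the call in the hunk above.
y = F.conv3d(x, weight=w, bias=b, stride=1, padding=1, dilation=1, groups=1)
print(y.shape)  # torch.Size([1, 8, 8, 16, 16])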
@@ -79,7 +79,6 @@ class MMWeight(MMWeightTemplate):
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

     def apply(self, input_tensor):
-        # if input_tensor.dtype != torch.float
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
@@ -143,6 +142,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.load_func(weight_dict)
         if self.weight_need_transpose:
             self.weight = self.weight.t()
+            self.pinned_weight = self.pinned_weight.t()

     def clear(self):
         attrs = ["weight", "weight_scale", "bias"]
......
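The added line in the hunk above keeps the pinned CPU staging buffer consistent with the transposed GPU weight. A minimal sketch of the assumed offload pattern (illustrative names, requires a CUDA device; not the repo's exact code):

import torch

weight = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
pinned_weight = torch.empty(weight.shape, pin_memory=True, dtype=weight.dtype)

# Transpose both views together so their shapes keep matching.
weight = weight.t()
pinned_weight = pinned_weight.t()

# Offload: copy_ into pinned memory can run asynchronously with non_blocking=True,
# and it would fail with a shape mismatch if only one side had been transposed.
pinned_weight.copy_(weight, non_blocking=True)
torch.cuda.synchronize()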
@@ -72,10 +72,6 @@ def apply_rotary_emb(x, freqs_i):
     # Apply rotary embedding
     x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
     x_i = torch.cat([x_i, x[seq_len:]])

-    # if GET_DTYPE() == "BF16":
-    #     x_i = x_i.to(torch.bfloat16)
-    # else:
-    #     x_i = x_i.float()
     return x_i.to(torch.bfloat16)
......
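For context, an illustrative sketch of the rotary-embedding step this hunk simplifies (hypothetical helper name; assumes x is [seq_len, n_heads, head_dim] and freqs_i is a complex tensor broadcastable to [seq_len, n_heads, head_dim // 2]); the output is now always bfloat16, matching the change above:

import torch

def apply_rotary_emb_sketch(x, freqs_i):
    # Pair up the last dimension and view it as complex numbers.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Rotate by the complex frequencies, then flatten back to interleaved real pairs.
    x_out = torch.view_as_real(x_c * freqs_i).flatten(2)
    return x_out.to(torch.bfloat16)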
@@ -37,7 +37,6 @@ class WanModel:
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
-            assert GET_DTYPE() == "BF16"

         self.device = device
         self._init_infer_class()
@@ -63,13 +62,11 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

-    def _load_safetensor_to_dict(self, file_path):
+    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
         with safe_open(file_path, framework="pt") as f:
-            use_bf16 = GET_DTYPE() == "BF16"
-            skip_bf16 = {"norm", "embedding", "modulation", "time"}
             return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

-    def _load_ckpt(self):
+    def _load_ckpt(self, use_bf16, skip_bf16):
         safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -77,41 +74,41 @@ class WanModel:
             raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")

         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path)
+            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
             weight_dict.update(file_weights)
         return weight_dict

-    def _load_quant_ckpt(self):
+    def _load_quant_ckpt(self, use_bf16, skip_bf16):
         ckpt_path = self.config.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")

-        if ckpt_path.endswith(".pth"):
-            logger.info(f"Loading {ckpt_path} as PyTorch model.")
-            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
-        else:
-            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
-            if not index_files:
-                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
+        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
+        if not index_files:
+            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")

-            index_path = os.path.join(ckpt_path, index_files[0])
-            logger.info(f" Using safetensors index: {index_path}")
+        index_path = os.path.join(ckpt_path, index_files[0])
+        logger.info(f" Using safetensors index: {index_path}")

-            with open(index_path, "r") as f:
-                index_data = json.load(f)
+        with open(index_path, "r") as f:
+            index_data = json.load(f)

-            weight_dict = {}
-            for filename in set(index_data["weight_map"].values()):
-                safetensor_path = os.path.join(ckpt_path, filename)
-                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
-                    logger.info(f"Loading weights from {safetensor_path}")
-                    for k in f.keys():
-                        weight_dict[k] = f.get_tensor(k).pin_memory()
-                        if weight_dict[k].dtype == torch.float:
-                            weight_dict[k] = weight_dict[k].pin_memory().to(torch.bfloat16)
+        weight_dict = {}
+        for filename in set(index_data["weight_map"].values()):
+            safetensor_path = os.path.join(ckpt_path, filename)
+            with safe_open(safetensor_path, framework="pt") as f:
+                logger.info(f"Loading weights from {safetensor_path}")
+                for k in f.keys():
+                    if f.get_tensor(k).dtype == torch.float:
+                        if use_bf16 or all(s not in k for s in skip_bf16):
+                            weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                        else:
+                            weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                    else:
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

         return weight_dict

-    def _load_quant_split_ckpt(self):
+    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
         lazy_load_model_path = self.config.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")

         pre_post_weight_dict, transformer_weight_dict = {}, {}
@@ -119,9 +116,13 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
-                pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
-                if pre_post_weight_dict[k].dtype == torch.float:
-                    pre_post_weight_dict[k] = pre_post_weight_dict[k].pin_memory().to(torch.bfloat16)
+                if f.get_tensor(k).dtype == torch.float:
+                    if use_bf16 or all(s not in k for s in skip_bf16):
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    else:
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                else:
+                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

         safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -132,24 +133,30 @@ class WanModel:
             with safe_open(file_path, framework="pt") as f:
                 for k in f.keys():
                     if "modulation" in k:
-                        transformer_weight_dict[k] = f.get_tensor(k).pin_memory()
-                        if transformer_weight_dict[k].dtype == torch.float:
-                            transformer_weight_dict[k] = transformer_weight_dict[k].pin_memory().to(torch.bfloat16)
+                        if f.get_tensor(k).dtype == torch.float:
+                            if use_bf16 or all(s not in k for s in skip_bf16):
+                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                            else:
+                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

         return pre_post_weight_dict, transformer_weight_dict

     def _init_weights(self, weight_dict=None):
+        use_bf16 = GET_DTYPE() == "BF16"
+        # Some layers run with float32 to achieve high accuracy
+        skip_bf16 = {"norm", "embedding", "modulation", "time"}
+
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt()
+                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
             else:
                 if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt()
+                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
                 else:
                     (
                         self.original_weight_dict,
                         self.transformer_weight_dict,
-                    ) = self._load_quant_split_ckpt()
+                    ) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
         else:
             self.original_weight_dict = weight_dict

         # init weights
......
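The common rule introduced across the loaders above: float32 tensors are cast to bfloat16 unless the run is not in BF16 mode and the parameter name matches one of the skip_bf16 substrings (norm / embedding / modulation / time layers stay in float32 for accuracy). A standalone sketch of that rule (helper name and example keys are illustrative, not the repo's API):

import torch

SKIP_BF16 = {"norm", "embedding", "modulation", "time"}

def maybe_cast_to_bf16(name, tensor, use_bf16, skip_bf16=SKIP_BF16):
    if tensor.dtype == torch.float and (use_bf16 or all(s not in name for s in skip_bf16)):
        return tensor.to(torch.bfloat16)
    return tensor

# With use_bf16=False, a norm weight keeps float32 while an attention weight is cast down.
ln_w = maybe_cast_to_bf16("blocks.0.norm1.weight", torch.ones(8), use_bf16=False)
qkv_w = maybe_cast_to_bf16("blocks.0.self_attn.q.weight", torch.ones(8), use_bf16=False)
print(ln_w.dtype, qkv_w.dtype)  # torch.float32 torch.bfloat16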
@@ -519,7 +519,7 @@ def convert_weights(args):
 def copy_non_weight_files(source_dir, target_dir):
-    ignore_extensions = [".pth", ".pt", ".safetensors"]
+    ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"]

     logger.info(f"Start copying non-weighted files and subdirectories...")
@@ -586,47 +586,48 @@ def main():
     )

     args = parser.parse_args()

-    if args.dtype == "torch.int8":
-        args.dtype = torch.int8
-    elif args.dtype == "torch.float8_e4m3fn":
-        args.dtype = torch.float8_e4m3fn
-    else:
-        raise ValueError(f"Not support dtype :{args.dtype}")
-
-    model_type_keys_map = {
-        "wan_dit": {
-            "key_idx": 2,
-            "target_keys": ["self_attn", "cross_attn", "ffn"],
-            "ignore_key": None,
-        },
-        "hunyuan_dit": {
-            "key_idx": 2,
-            "target_keys": [
-                "img_mod",
-                "img_attn_qkv",
-                "img_attn_proj",
-                "img_mlp",
-                "txt_mod",
-                "txt_attn_qkv",
-                "txt_attn_proj",
-                "txt_mlp",
-                "linear1",
-                "linear2",
-                "modulation",
-            ],
-            "ignore_key": None,
-        },
-        "wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
-        "wan_clip": {
-            "key_idx": 3,
-            "target_keys": ["attn", "mlp"],
-            "ignore_key": "textual",
-        },
-    }
-
-    args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
-    args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
-    args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
+    if args.quantized:
+        if args.dtype == "torch.int8":
+            args.dtype = torch.int8
+        elif args.dtype == "torch.float8_e4m3fn":
+            args.dtype = torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Not support dtype :{args.dtype}")
+
+        model_type_keys_map = {
+            "wan_dit": {
+                "key_idx": 2,
+                "target_keys": ["self_attn", "cross_attn", "ffn"],
+                "ignore_key": None,
+            },
+            "hunyuan_dit": {
+                "key_idx": 2,
+                "target_keys": [
+                    "img_mod",
+                    "img_attn_qkv",
+                    "img_attn_proj",
+                    "img_mlp",
+                    "txt_mod",
+                    "txt_attn_qkv",
+                    "txt_attn_proj",
+                    "txt_mlp",
+                    "linear1",
+                    "linear2",
+                    "modulation",
+                ],
+                "ignore_key": None,
+            },
+            "wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
+            "wan_clip": {
+                "key_idx": 3,
+                "target_keys": ["attn", "mlp"],
+                "ignore_key": "textual",
+            },
+        }
+
+        args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
+        args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
+        args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]

     if os.path.isfile(args.output):
         raise ValueError("Output path must be a directory, not a file")
......
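A hypothetical sketch of the extension filter implied by the ignore_extensions change above: walk the source directory and copy everything except weight shards and their index file, so configs and tokenizer files still reach the output directory. Only the ignore list comes from the diff; the helper itself is illustrative:

import os
import shutil

IGNORE_EXTENSIONS = [".pth", ".pt", ".safetensors", ".index.json"]

def copy_non_weight_files_sketch(source_dir, target_dir):
    for root, _, files in os.walk(source_dir):
        for name in files:
            if any(name.endswith(ext) for ext in IGNORE_EXTENSIONS):
                continue  # skip weights and the safetensors index
            src = os.path.join(root, name)
            dst = os.path.join(target_dir, os.path.relpath(src, source_dir))
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.copy2(src, dst)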
@@ -33,7 +33,7 @@ python converter.py \
 --quantized \
 --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name wan_int8 \
 --dtype torch.int8 \
 --model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
 --quantized \
 --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name wan_fp8 \
 --dtype torch.float8_e4m3fn \
 --model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
 --quantized \
 --source /Path/To/hunyuan/lightx2v_format/i2v/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name hunyuan_int8 \
 --dtype torch.int8 \
 --model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
 --quantized \
 --source /Path/To/hunyuan/lightx2v_format/i2v/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name hunyuan_fp8 \
 --dtype torch.float8_e4m3fn \
 --model_type hunyuan_dit
......
@@ -33,7 +33,7 @@ python converter.py \
 --quantized \
 --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name wan_int8 \
 --dtype torch.int8 \
 --model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
 --quantized \
 --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name wan_fp8 \
 --dtype torch.float8_e4m3fn \
 --model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
 --quantized \
 --source /Path/To/hunyuan/lightx2v_format/i2v/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name hunyuan_int8 \
 --dtype torch.int8 \
 --model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
 --quantized \
 --source /Path/To/hunyuan/lightx2v_format/i2v/ \
 --output /Path/To/output \
---output_ext .pth\
+--output_ext .safetensors \
 --output_name hunyuan_fp8 \
 --dtype torch.float8_e4m3fn \
 --model_type hunyuan_dit
......
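To sanity-check the .safetensors output produced by the commands above (the path below is a placeholder built from the example --output, --output_name and --output_ext flags), the file can be opened with safetensors and its tensor names, shapes and dtypes listed:

from safetensors import safe_open

with safe_open("/Path/To/output/wan_int8.safetensors", framework="pt") as f:
    for key in list(f.keys())[:10]:
        t = f.get_tensor(key)
        print(key, tuple(t.shape), t.dtype)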