Commit dbfa688b authored by helloyongyang's avatar helloyongyang
Browse files

Fix weight loading without quant

parent 420fec7f
...@@ -64,7 +64,7 @@ class MMWeight(MMWeightTemplate): ...@@ -64,7 +64,7 @@ class MMWeight(MMWeightTemplate):
super().__init__(weight_name, bias_name) super().__init__(weight_name, bias_name)
def load(self, weight_dict): def load(self, weight_dict):
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False): if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False) or self.config.get("mm_type", "Default") == "Default":
self.weight = weight_dict[self.weight_name].t().cuda() self.weight = weight_dict[self.weight_name].t().cuda()
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
else: else:
......
...@@ -73,7 +73,7 @@ class HunyuanModel: ...@@ -73,7 +73,7 @@ class HunyuanModel:
return weight_dict return weight_dict
def _init_weights(self): def _init_weights(self):
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False): if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
weight_dict = self._load_ckpt() weight_dict = self._load_ckpt()
else: else:
weight_dict = self._load_ckpt_quant_model() weight_dict = self._load_ckpt_quant_model()
......
...@@ -90,7 +90,7 @@ class WanModel: ...@@ -90,7 +90,7 @@ class WanModel:
def _init_weights(self, weight_dict=None): def _init_weights(self, weight_dict=None):
if weight_dict is None: if weight_dict is None:
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False): if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False) or self.config["mm_config"].get("mm_type", "Default") == "Default":
self.original_weight_dict = self._load_ckpt() self.original_weight_dict = self._load_ckpt()
else: else:
self.original_weight_dict = self._load_ckpt_quant_model() self.original_weight_dict = self._load_ckpt_quant_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment