Commit 53c0d05c authored by helloyongyang

update save/load of mm weights

parent 78640ad0
@@ -49,14 +49,6 @@ class MMWeightTemplate(metaclass=ABCMeta):
         if self.bias is not None:
             self.bias = self.bias.cuda(non_blocking=non_blocking)
 
-    def state_dict(self, destination=None):
-        if destination is None:
-            destination = {}
-        destination[self.weight_name] = self.weight.cpu().detach().clone()
-        if self.bias is not None:
-            destination[self.bias_name] = self.bias.cpu().detach().clone()
-        return destination
-
 
 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
@@ -64,12 +56,8 @@ class MMWeight(MMWeightTemplate):
         super().__init__(weight_name, bias_name)
 
     def load(self, weight_dict):
-        if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False) or self.config.get("mm_type", "Default") == "Default":
-            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
-        else:
-            self.weight = weight_dict[self.weight_name].cuda()
-            self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.weight = weight_dict[self.weight_name].t().cuda()
+        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
 
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
@@ -80,6 +68,14 @@ class MMWeight(MMWeightTemplate):
             return torch.mm(input_tensor, self.weight, out=output_tensor)
         return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
 
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
+        if self.bias is not None:
+            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        return destination
+
 
 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
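Note: with this change the Default path keeps one convention end to end: the checkpoint stores (out_features, in_features), load() holds a transposed (in_features, out_features) copy for torch.mm/addmm, and the new state_dict() transposes back before saving. A minimal sketch of that round trip (plain torch on CPU, no registry; illustration only, not the repo's code):

import torch

# Checkpoint layout: (out, in); runtime layout: (in, out) for torch.mm/addmm.
ckpt = {"xx.weight": torch.randn(8192, 4096), "xx.bias": torch.randn(8192)}

weight = ckpt["xx.weight"].t()                      # (4096, 8192), as in load()
bias = ckpt["xx.bias"]

x = torch.randn(2, 4096)
out = torch.addmm(bias, x, weight)                  # (2, 8192), as in apply()

# state_dict() undoes the transpose, so saving reproduces the checkpoint layout.
saved = weight.cpu().detach().clone().t().contiguous()
assert torch.equal(saved, ckpt["xx.weight"])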
@@ -106,6 +102,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
 
     def load(self, weight_dict):
         self.load_func(weight_dict)
+        if self.weight_need_transpose:
+            self.weight = self.weight.t()
 
     def load_quantized(self, weight_dict):
         self.weight = weight_dict[self.weight_name].cuda()
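Note: the weight_need_transpose flip now happens exactly once, in MMWeightQuantTemplate.load() right after load_func(), instead of inside each quantize branch. A rough sketch of the resulting control flow, assuming load_func() leaves self.weight in the checkpoint's (out, in) layout (simplified stand-in, not the repo's wiring):

import torch

class QuantLoadSketch:
    """Simplified stand-in for MMWeightQuantTemplate's load path."""

    def __init__(self, weight_name, weight_need_transpose):
        self.weight_name = weight_name
        self.weight_need_transpose = weight_need_transpose

    def load_func(self, weight_dict):
        # Quantize a float weight or pick up a pre-quantized one; either way
        # the result stays in (out, in) layout at this point.
        self.weight = weight_dict[self.weight_name]

    def load(self, weight_dict):
        self.load_func(weight_dict)
        if self.weight_need_transpose:      # the single place the layout is flipped
            self.weight = self.weight.t()

w = QuantLoadSketch("xx.weight", weight_need_transpose=True)
w.load({"xx.weight": torch.randn(8192, 4096)})
assert w.weight.shape == (4096, 8192)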
@@ -118,8 +116,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
-            if self.weight_need_transpose:
-                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
@@ -131,8 +127,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
-            if self.weight_need_transpose:
-                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
@@ -141,8 +135,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name].cuda()
             self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
-            if self.weight_need_transpose:
-                self.weight = self.weight.t()
         else:
             self.load_quantized(weight_dict)
         self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
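Note: the fp8 per-channel, int8 per-channel, and fp8 per-block branches above no longer transpose after quantizing; they leave the quantized weight in (out, in) layout and rely on the template's load() for the flip. For illustration, a rough per-output-channel fp8 quantizer in that layout (hypothetical stand-in; the repo uses w_quantizer.real_quant_tensor, whose internals are not shown here):

import torch

def per_channel_fp8_quant(w):
    # One scale per output row of an (out, in) weight; e4m3 max magnitude ~448.
    amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)   # (out, 1)
    scale = (amax / 448.0).to(torch.float32)
    q = (w / scale).to(torch.float8_e4m3fn)                    # still (out, in)
    return q, scale

q, scale = per_channel_fp8_quant(torch.randn(8192, 4096))
assert q.shape == (8192, 4096) and scale.shape == (8192, 1)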
@@ -193,7 +185,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
-        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        if self.weight_need_transpose:
+            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
+        else:
+            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         if hasattr(self, "weight_scale"):
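Note: the quantized state_dict() is now symmetric with load(): if the runtime copy was transposed, it is transposed back before saving, so the stored tensor is always contiguous in (out, in) layout alongside its weight_scale. A small sketch of the saved entries (illustration only, not the repo's code):

import torch

runtime_weight = torch.randn(8192, 4096).to(torch.float8_e4m3fn).t()   # flipped by load()
weight_scale = torch.randn(8192, 1).to(torch.float32)
weight_need_transpose = True

destination = {}
w = runtime_weight.cpu().detach().clone()
destination["xx.weight"] = (w.t() if weight_need_transpose else w).contiguous()
destination["xx.weight_scale"] = weight_scale.cpu().detach().clone()

assert destination["xx.weight"].shape == (8192, 4096)                  # checkpoint layout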
@@ -478,7 +473,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 
 if __name__ == "__main__":
     weight_dict = {
-        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn).t(),
+        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
         "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
     }
...
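Note: the __main__ fixture follows the same convention: "xx.weight" is now stored without the trailing .t(), i.e. in the (out, in) layout that the new state_dict() writes and that load() transposes on the way in, while the (8192, 1) weight_scale keeps one scale per output channel. A quick shape check of that convention (illustration only):

import torch

weight = torch.randn(8192, 4096).to(torch.float8_e4m3fn)   # (out, in), no .t()
weight_scale = torch.randn(8192, 1).to(torch.float32)
assert weight.shape[0] == weight_scale.shape[0]             # one scale per output row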