"test/vscode:/vscode.git/clone" did not exist on "0f4fb19bc8cf87f518b0273ee970d5b3eef8beb5"
Commit f085ede3 authored by gushiqiao, committed by GitHub

[Fix] Fix high peak memory bug (#313)

parent 6d9e6c0a
@@ -84,19 +84,25 @@ class MMWeight(MMWeightTemplate):
     def load(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        weight_shape = weight_dict[self.weight_name].t().shape
-        weight_dtype = weight_dict[self.weight_name].dtype
-        self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
-        self.weight = self.weight.copy_(weight_dict[self.weight_name].t())
-        if self.bias_name is not None:
-            bias_shape = weight_dict[self.bias_name].shape
-            bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
-        else:
-            self.bias = None
+        if device.type == "cuda":
+            self.weight = weight_dict[self.weight_name].t()
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+        elif device.type == "cpu":
+            weight_shape = weight_dict[self.weight_name].t().shape
+            weight_dtype = weight_dict[self.weight_name].dtype
+            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.weight.copy_(weight_dict[self.weight_name].t())
+            if self.bias_name is not None:
+                bias_shape = weight_dict[self.bias_name].shape
+                bias_dtype = weight_dict[self.bias_name].dtype
+                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.bias.copy_(weight_dict[self.bias_name])
+            else:
+                self.bias = None
+        else:
+            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

     def _calculate_size(self):
         if self.bias is not None:
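Why this fixes the peak: the removed pattern torch.empty(..., pin_memory=True).to(device) allocates a full pinned copy of every tensor in host RAM even when the source already sits on the GPU, and that pinned buffer is still alive while .to(device) allocates the device tensor. A minimal sketch of the before/after behavior (illustrative only, not code from this commit):

import torch

src = torch.randn(4096, 4096, device="cuda")  # checkpoint tensor already on GPU

# Removed pattern: a pinned host buffer is allocated regardless of where the
# source lives; .to(device) then allocates the device tensor while the pinned
# buffer is still referenced, so host memory peaks one extra copy per weight.
staging = torch.empty(src.shape, pin_memory=True, dtype=src.dtype)
weight = staging.to(src.device)  # device alloc; `staging` is now dead weight
weight.copy_(src)

# New pattern: pinned memory is only allocated when the tensor is on the CPU.
if src.device.type == "cuda":
    weight = src
elif src.device.type == "cpu":
    weight = torch.empty(src.shape, pin_memory=True, dtype=src.dtype)
    weight.copy_(src)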
@@ -149,7 +155,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     # weight load functions
     # =========================
-    def load_from_disk(self):
+    def load_from_disk(self):  # Need Rewrite
         if not torch._dynamo.is_compiling():
             self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
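The tensors read in load_from_disk are pinned so they can later be moved to the GPU asynchronously: page-locked host memory is what makes non_blocking=True transfers actually overlap with compute. A hedged sketch of that transfer pattern (the stream handling here is an assumption, not shown in this diff):

import torch

pinned = torch.empty(1024, 1024, pin_memory=True)  # stands in for self.weight
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    on_gpu = pinned.to("cuda", non_blocking=True)  # returns before the copy completes
stream.synchronize()  # wait for the transfer before reading on_gpu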
@@ -180,28 +186,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def _calculate_size(self):
         if self.bias is not None:
             return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()
         return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

     def load_quantized(self, weight_dict):
         device = weight_dict[self.weight_name].device
-        weight_shape = weight_dict[self.weight_name].shape
-        weight_dtype = weight_dict[self.weight_name].dtype
-        self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
-        self.weight = self.weight.copy_(weight_dict[self.weight_name])
-
-        weight_scale_shape = weight_dict[self.weight_scale_name].shape
-        weight_scale_dtype = torch.float
-        self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype).to(device)
-        self.weight_scale = self.weight_scale.copy_(weight_dict[self.weight_scale_name])
-
-        if self.bias_name is not None:
-            bias_shape = weight_dict[self.bias_name].shape
-            bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
-        else:
-            self.bias = None
+        if device.type == "cuda":
+            self.weight = weight_dict[self.weight_name]
+            self.weight_scale = weight_dict[self.weight_scale_name].float()
+        elif device.type == "cpu":
+            weight_shape = weight_dict[self.weight_name].shape
+            weight_dtype = weight_dict[self.weight_name].dtype
+            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.weight.copy_(weight_dict[self.weight_name])
+
+            weight_scale_shape = weight_dict[self.weight_scale_name].shape
+            weight_scale_dtype = torch.float
+            self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+            self.weight_scale.copy_(weight_dict[self.weight_scale_name])
+        else:
+            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

     def load_fp8_perchannel_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
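Note that weight_scale is materialized as float32 (torch.float) on both branches. For intuition, per-channel symmetric quantization keeps one float scale per output channel, and a reference dequantize-then-matmul looks roughly like this (an int8 sketch under assumed shapes, not this repo's fp8 kernel):

import torch

w = torch.randn(256, 128)                            # full-precision weight
scale = w.abs().amax(dim=1, keepdim=True) / 127.0    # float32 per-channel scale
w_q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)

x = torch.randn(8, 128)
y = x @ (w_q.float() * scale).t()                    # dequantize, then matmul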
@@ -215,10 +218,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias_name is not None:
             device = weight_dict[self.bias_name].device
-            bias_shape = weight_dict[self.bias_name].shape
-            bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+            if device.type == "cuda":
+                self.bias = weight_dict[self.bias_name]
+            elif device.type == "cpu":
+                bias_shape = weight_dict[self.bias_name].shape
+                bias_dtype = weight_dict[self.bias_name].dtype
+                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.bias.copy_(weight_dict[self.bias_name])
+            else:
+                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
@@ -234,10 +242,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias_name is not None:
             device = weight_dict[self.bias_name].device
-            bias_shape = weight_dict[self.bias_name].shape
-            bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+            if device.type == "cuda":
+                self.bias = weight_dict[self.bias_name]
+            elif device.type == "cpu":
+                bias_shape = weight_dict[self.bias_name].shape
+                bias_dtype = weight_dict[self.bias_name].dtype
+                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.bias.copy_(weight_dict[self.bias_name])
+            else:
+                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
@@ -250,10 +263,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias_name is not None:
             device = weight_dict[self.bias_name].device
-            bias_shape = weight_dict[self.bias_name].shape
-            bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+            if device.type == "cuda":
+                self.bias = weight_dict[self.bias_name]
+            elif device.type == "cpu":
+                bias_shape = weight_dict[self.bias_name].shape
+                bias_dtype = weight_dict[self.bias_name].dtype
+                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.bias.copy_(weight_dict[self.bias_name])
+            else:
+                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
@@ -735,8 +753,8 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if self.bias_name is not None:
             bias_shape = weight_dict[self.bias_name].shape
             bias_dtype = weight_dict[self.bias_name].dtype
-            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
-            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+            self.bias.copy_(weight_dict[self.bias_name])
         else:
             self.bias = None
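In the Marlin path the fix is simply to stop chaining .to(device) onto the pinned allocation. The reassignment self.bias = self.bias.copy_(...) was also dropped because Tensor.copy_ is in-place and returns its receiver:

import torch

dst = torch.empty(4, pin_memory=True)
src = torch.arange(4, dtype=torch.float32)
assert dst.copy_(src) is dst  # copy_ fills dst and returns the same tensor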
...