Commit d2a10bd9 authored by Abhinav Kulkarni

Added torch.cuda.empty_cache()

parent ab536fb1
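
The changes all apply the same PyTorch memory-management pattern: drop the last reference to a large temporary tensor, then call torch.cuda.empty_cache() so the caching allocator returns its now-unused blocks to the driver before the next memory-heavy step. Below is a minimal sketch of that pattern, not part of the diff; the names and sizes (linears, 4096) are placeholders standing in for linears2scale and the concatenated weight in _search_module_scale:

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    linears = [nn.Linear(4096, 4096).to(device) for _ in range(3)]

    with torch.no_grad():  # no autograd graph, so the temporary is actually freeable
        # Large temporary, analogous to the concatenated weight in the diff.
        weight = torch.cat([m.weight for m in linears], dim=0)
        w_max = weight.abs().amax(dim=1)  # keep only the small per-row statistic

    del weight                    # drop the last reference to the temporary
    if device == "cuda":
        torch.cuda.empty_cache()  # release cached, now-unused blocks

Note that empty_cache() cannot free memory that is still referenced, which is why the diff pairs it with del weight where a large temporary exists.
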
@@ -73,9 +73,9 @@ def build_model_and_enc(model_path):
     # all hf model
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     if "mpt" in config.__class__.__name__.lower():
-        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
+        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
     else:
-        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
     if args.load_quant:  # directly load quantized weights
         print("Loading pre-computed quantized weights...")
...
@@ -107,11 +107,14 @@ def auto_scale_block(module, module_kwargs,
     def _search_module_scale(block, linears2scale: list, x, kwargs={}):
         # w: co, ci
         # x: n, ci
-        x = x.to(next(block.parameters()).device)
         weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
         w_max = get_weight_scale(
             weight, q_group_size=q_config.get("q_group_size", -1))
+        # Clear GPU memory
+        del weight
+        torch.cuda.empty_cache()
+        x = x.to(next(block.parameters()).device)
         with torch.no_grad():
             org_out = block(x, **kwargs)
         if isinstance(org_out, tuple):
@@ -126,6 +129,8 @@ def auto_scale_block(module, module_kwargs,
         n_grid = 20
         history = []
+        # Clear GPU memory
+        torch.cuda.empty_cache()
         org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
         for ratio in range(n_grid):
             ratio = ratio * 1 / n_grid
...
@@ -135,6 +135,9 @@ def run_awq(
         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
+        # Clear GPU memory
+        torch.cuda.empty_cache()
         if auto_scale:  # if it applies, we should also modify the input_feat with scales
             scales_list = auto_scale_block(
                 layer, layer_kwargs,
@@ -145,6 +148,9 @@ def run_awq(
             apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
             # append prefix to make names global
             awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
+            # Clear GPU memory
+            torch.cuda.empty_cache()
         if mse_range:
             clip_list = auto_clip_block(layer,
...