"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b076abd17bb93a61ef9cc7d4265dab8b1e154dd8"
Commit 95cd9c2d authored by Abhinav Kulkarni

[Major] Add CPU offloading support for apply_scale, apply_clip, pseudo_quantize_model_weight, real_quantize_model_weight
parent 8e7e9ccc
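
The commit applies one pattern throughout: weights stay on the CPU by default, a tensor or submodule is moved to the GPU only for the operation that needs it, and is moved back immediately afterwards. A minimal sketch of that idea, assuming a CUDA device is available (the on_gpu helper below is illustrative and not part of this commit):

import torch
import torch.nn as nn
from contextlib import contextmanager

@contextmanager
def on_gpu(module: nn.Module):
    # Temporarily move a module to the GPU, then offload it back to the CPU.
    module.cuda()
    try:
        yield module
    finally:
        module.cpu()
        torch.cuda.empty_cache()

# Example: clamp a layer's weights on the GPU without keeping them resident there.
layer = nn.Linear(16, 16)
with on_gpu(layer):
    layer.weight.data.clamp_(-1.0, 1.0)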
@@ -114,8 +114,8 @@ def build_model_and_enc(model_path):
         exit(0)
     else:
         # Inference with fake quant
-        # Init model on GPUs:
-        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
+        # Init model on CPU:
+        kwargs = {"torch_dtype": torch.float16}
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, trust_remote_code=True, **kwargs)
 
@@ -146,6 +146,15 @@ def build_model_and_enc(model_path):
         exit(0)
     else:
         raise NotImplementedError
 
+    # Move the model to GPU (as much as possible) for LM evaluation
+    kwargs = {
+        "torch_dtype": torch.float16,
+        "device_map": "auto",
+        "max_memory": {0: "8GiB", "cpu": "99GiB"}
+    }
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, config=config, state_dict=model.state_dict(), trust_remote_code=True, **kwargs)
+
     return model, enc
 
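
With this change, the model is first built entirely on the CPU so the quantization passes do not need the full fp16 model in GPU memory, and only afterwards is it re-loaded with an automatic device map that fills the GPU up to a budget and spills the rest to CPU RAM. A sketch of that two-stage load, assuming a single GPU (device index 0) with roughly 8 GiB to spare and a transformers/accelerate version that supports device_map and max_memory; model_path and config are as in the diff:

import torch
from transformers import AutoModelForCausalLM

# Stage 1: materialize the model on the CPU in fp16 and quantize it in place.
model = AutoModelForCausalLM.from_pretrained(
    model_path, config=config, trust_remote_code=True,
    torch_dtype=torch.float16)
# ... apply_scale / apply_clip / pseudo_quantize_model_weight run here ...

# Stage 2: re-load the quantized state dict with an automatic device map,
# capping GPU 0 at 8 GiB and letting the remainder live in CPU RAM.
model = AutoModelForCausalLM.from_pretrained(
    model_path, config=config, trust_remote_code=True,
    state_dict=model.state_dict(),
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "8GiB", "cpu": "99GiB"})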
@@ -163,11 +172,10 @@ def main():
     # a hack here to auto set model group
     model, enc = build_model_and_enc(args.model_path)
 
-    lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
-
     if args.tasks is not None:
         task_names = args.tasks.split(",")
 
+        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
         results = evaluator.simple_evaluate(
             model=lm_eval_model,
             tasks=task_names,
@@ -86,8 +86,10 @@ def apply_clip(module, clip_list):
     from ..utils.module import get_op_by_name
     for name, max_val in clip_list:
         layer = get_op_by_name(module, name)
+        layer.cuda()
         max_val = max_val.to(layer.weight.device)
         org_shape = layer.weight.shape
         layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
         layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
         layer.weight.data = layer.weight.data.reshape(org_shape)
+        layer.cpu()
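
The clipping itself is unchanged: the weight is reshaped to line up with the per-group maxima, clamped, and restored; the new cuda()/cpu() calls only bracket that work so each layer visits the GPU briefly. The same bracketing is applied below to apply_scale and pseudo_quantize_model_weight. A standalone sketch of the per-layer step (the clip_layer_weights wrapper is illustrative; the body follows the diff):

import torch
import torch.nn as nn

def clip_layer_weights(layer: nn.Linear, max_val: torch.Tensor) -> None:
    # Clamp a layer's weights to +/- max_val on the GPU, then offload back to the CPU.
    layer.cuda()
    max_val = max_val.to(layer.weight.device)
    org_shape = layer.weight.shape
    w = layer.weight.data.reshape(*max_val.shape[:2], -1)
    w = torch.clamp(w, -max_val, max_val)
    layer.weight.data = w.reshape(org_shape)
    layer.cpu()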
@@ -320,6 +320,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
     for prev_op_name, layer_names, scales in scales_list:
         prev_op = get_op_by_name(module, prev_op_name)
         layers = [get_op_by_name(module, name) for name in layer_names]
+
+        prev_op.cuda()
+        for layer in layers:
+            layer.cuda()
 
         if isinstance(prev_op, nn.Linear):
             assert len(layers) == 1
@@ -338,4 +342,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         if input_feat_dict is not None:
             for layer_name in layer_names:
                 inp = input_feat_dict[layer_name]
-                inp.div_(scales.view(1, -1).to(inp.device))
\ No newline at end of file
+                inp.div_(scales.view(1, -1).to(inp.device))
+
+        prev_op.cpu()
+        for layer in layers:
+            layer.cpu()
@@ -98,7 +98,9 @@ def pseudo_quantize_model_weight(
     for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
         named_linears = get_named_linears(layers[i])
         for n, m in named_linears.items():
+            m.cuda()
             m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
+            m.cpu()
 
 
 @torch.no_grad()
@@ -121,11 +123,15 @@ def real_quantize_model_weight(
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], True)
             else:
+                module.cuda()
                 module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], False, scales, zeros)
+                module.cpu()
             set_op_by_name(layer, name, q_linear)
             torch.cuda.empty_cache()
             gc.collect()
+
+    model.tie_weights()
\ No newline at end of file
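
The final model.tie_weights() matters because set_op_by_name swaps nn.Linear modules for WQLinear ones, and replacing a module discards any parameter sharing (for example between the input embedding and the lm_head). A toy illustration of why re-tying is needed after such a swap (the modules here are placeholders, not the model's real ones):

import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight                 # tied: both names point at the same Parameter
assert head.weight is emb.weight

head = nn.Linear(4, 10, bias=False)      # swap in a new module, as set_op_by_name does
assert head.weight is not emb.weight     # the tie is gone and must be re-established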