Commit f22c2a35 authored by Jiaming Tang

[Minor] accelerate loading quantized model

parent 5f377eff
@@ -122,6 +122,8 @@ def real_quantize_model_weight(
             if init_only:
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], True)
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
             else:
                 module.cuda()
                 module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
@@ -130,7 +132,10 @@ def real_quantize_model_weight(
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], False, scales, zeros)
                 module.cpu()
-            q_linear.to(next(layer.parameters()).device)
-            set_op_by_name(layer, name, q_linear)
-            torch.cuda.empty_cache()
-            gc.collect()
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
+                torch.cuda.empty_cache()
+                gc.collect()
+
+    torch.cuda.empty_cache()
+    gc.collect()
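Reading the two hunks together: device placement (q_linear.to(next(layer.parameters()).device)) and set_op_by_name now run inside each branch rather than on a shared path after the if/else, and the per-module torch.cuda.empty_cache()/gc.collect() pair runs only on the real-quantization path, plus once after the whole loop. Loading a pre-quantized checkpoint (init_only=True) therefore no longer pays a full garbage-collection pass for every linear layer. Below is a minimal standalone sketch of that cost difference; fake_load, the module count, and the tensor shape are illustrative stand-ins, not code from this repo.

```python
import gc
import time

import torch


def fake_load(n_modules: int, cleanup_each_module: bool) -> float:
    """Simulate swapping in n_modules quantized linear layers.

    cleanup_each_module=True mimics the old shared path (empty_cache +
    gc.collect after every module); False mimics the new init-only path
    (cleanup once, after the whole loop).
    """
    start = time.perf_counter()
    for _ in range(n_modules):
        # Stand-in for constructing a WQLinear: allocate a packed buffer.
        packed = torch.empty(4096, 4096 // 8, dtype=torch.int32)
        del packed
        if cleanup_each_module:
            torch.cuda.empty_cache()  # no-op without CUDA, cheap with it
            gc.collect()              # a full GC pass per module adds up
    torch.cuda.empty_cache()          # one-shot cleanup (new behavior)
    gc.collect()
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"per-module cleanup: {fake_load(200, True):.2f} s")
    print(f"one-shot cleanup:   {fake_load(200, False):.2f} s")
```

On a real model the loop also pays for tensor allocation and device transfer, but in the init_only path the per-module gc.collect() was pure overhead, since little memory is actually freed between iterations.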