Commit f22c2a35 authored by Jiaming Tang

[Minor] accelerate loading quantized model

parent 5f377eff
@@ -122,6 +122,8 @@ def real_quantize_model_weight(
             if init_only:
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], True)
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
             else:
                 module.cuda()
                 module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
@@ -130,7 +132,10 @@ def real_quantize_model_weight(
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], False, scales, zeros)
                 module.cpu()
-            q_linear.to(next(layer.parameters()).device)
-            set_op_by_name(layer, name, q_linear)
-            torch.cuda.empty_cache()
-            gc.collect()
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
+                torch.cuda.empty_cache()
+                gc.collect()
+
+    torch.cuda.empty_cache()
+    gc.collect()
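
For context, the `init_only` path touched by this commit is the one used when loading an already-quantized checkpoint: it creates empty `WQLinear` modules whose weights are later filled in by `load_state_dict`, so the per-module `torch.cuda.empty_cache()` / `gc.collect()` calls (now confined to the on-the-fly quantization branch) were pure overhead there. A minimal usage sketch of that loading flow, assuming `model` and `q_config` are prepared as elsewhere in this repository and using a hypothetical checkpoint path:

import torch

# Sketch only: "quant_cache/awq-w4.pt" is a hypothetical checkpoint path;
# `model` and `q_config` are assumed to be set up as elsewhere in this repo.
# init_only=True creates empty WQLinear modules without quantizing weights.
real_quantize_model_weight(model, w_bit=4, q_config=q_config, init_only=True)

# Restore the pre-quantized weights into the freshly created WQLinear
# modules; after this commit, no cache flush runs per module on this path.
state_dict = torch.load("quant_cache/awq-w4.pt", map_location="cpu")
model.load_state_dict(state_dict)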