Commit 9b427ebc authored by Jiaming Tang

Add compatibility with GQA & optimize multi-GPU memory allocation

parent dc139757
@@ -5,6 +5,7 @@ import argparse
 import os
 import json
 from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
+from accelerate.utils.modeling import get_balanced_memory
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
@@ -162,7 +163,7 @@ def build_model_and_enc(model_path):
         raise NotImplementedError
     # Move the model to GPU (as much as possible) for LM evaluation
-    kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+    kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
     device_map = infer_auto_device_map(
         model,
         # TODO: can we remove this?
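
For context on the hunk above: instead of passing the user-supplied `max_memory` dict straight to `infer_auto_device_map` (which tends to fill the first GPU before touching the rest), the commit routes it through accelerate's `get_balanced_memory`, which computes a per-device budget that spreads the layers evenly across the visible GPUs. A minimal sketch of how these calls fit together is below; the checkpoint name and the `no_split_module_classes` value are illustrative assumptions, not taken from this commit.

```python
# Sketch only: balanced multi-GPU dispatch with accelerate (names below are examples).
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.utils.modeling import get_balanced_memory
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")  # illustrative checkpoint

# Derive a per-GPU budget that balances the layers across all visible devices;
# passing max_memory=None lets accelerate read each GPU's available memory itself.
max_memory = get_balanced_memory(
    model, max_memory=None, no_split_module_classes=["OPTDecoderLayer"]
)

# Place whole decoder layers under that budget, then move the weights into place.
device_map = infer_auto_device_map(
    model, max_memory=max_memory, no_split_module_classes=["OPTDecoderLayer"]
)
model = dispatch_model(model, device_map=device_map)
```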
@@ -213,11 +213,12 @@ def auto_scale_block(module, module_kwargs,
             module2inspect=module.self_attn, kwargs=module_kwargs,
         ))
         # attn out
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attn.v_proj,
-            layers=[module.self_attn.o_proj],
-            inp=input_feat['self_attn.o_proj'],
-        ))
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            scales_list.append(_auto_get_scale(
+                prev_op=module.self_attn.v_proj,
+                layers=[module.self_attn.o_proj],
+                inp=input_feat['self_attn.o_proj'],
+            ))
         # fc1
         scales_list.append(_auto_get_scale(
             prev_op=module.post_attention_layernorm,
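
The GQA change: pairing `prev_op=module.self_attn.v_proj` with `layers=[module.self_attn.o_proj]` absorbs the per-channel scales found for `o_proj`'s input into `v_proj`'s output channels. With plain multi-head attention those two dimensions coincide, but with grouped-query attention `v_proj` produces only `num_kv_heads * head_dim` channels while `o_proj` consumes `num_heads * head_dim`, so there is no one-to-one channel mapping and the commit skips that scaling pair whenever the weight shapes differ. A small sketch of the shape mismatch follows; the dimensions are illustrative (a Llama-2-70B-like GQA config), not taken from the commit.

```python
# Sketch only: the shape check that gates the attn-out scaling pair under GQA.
# Dimensions are illustrative, not from the commit.
import torch.nn as nn

hidden_size, num_heads, num_kv_heads = 8192, 64, 8
head_dim = hidden_size // num_heads

v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)  # weight: (1024, 8192)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)     # weight: (8192, 8192)

# MHA: both weights are (hidden_size, hidden_size), so o_proj's input scales can be
# divided out of v_proj's output channels. GQA: v_proj has fewer output channels than
# o_proj has input channels, so that fold has no one-to-one channel mapping.
print(v_proj.weight.shape == o_proj.weight.shape)  # False here; True for plain MHA
```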