Unverified commit efea69e1 authored by Ji Lin, committed by GitHub

Merge pull request #67 from Sakits/main

Add compatibility with GQA & optimize multi-GPU memory allocation
parents dc139757 b190df35
@@ -5,6 +5,7 @@ import argparse
 import os
 import json
 from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
+from accelerate.utils.modeling import get_balanced_memory
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight

@@ -162,7 +163,7 @@ def build_model_and_enc(model_path):
             raise NotImplementedError
     # Move the model to GPU (as much as possible) for LM evaluation
-    kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+    kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
     device_map = infer_auto_device_map(
         model,
         # TODO: can we remove this?
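For context on the max_memory change above, here is a minimal sketch (not code from this PR) of how get_balanced_memory is typically combined with infer_auto_device_map and dispatch_model in accelerate: without a per-device budget, the inferred device map fills GPU 0 before spilling to the next device, whereas get_balanced_memory computes per-GPU limits that spread the layers roughly evenly across all visible GPUs. The model name and no_split_module_classes value below are illustrative assumptions, not values taken from this repository.

```python
# Illustrative sketch only; the model name and no_split_module_classes are
# assumptions, not values used by this repository's evaluation script.
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.utils.modeling import get_balanced_memory
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")

# Compute a per-GPU memory budget that balances layers across devices instead
# of letting infer_auto_device_map fill GPU 0 first and overflow to the rest.
max_memory = get_balanced_memory(model, None)  # e.g. {0: ..., 1: ..., "cpu": ...}

device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["OPTDecoderLayer"],  # keep each decoder block on one GPU
)
model = dispatch_model(model, device_map=device_map)
```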
@@ -213,11 +213,13 @@ def auto_scale_block(module, module_kwargs,
             module2inspect=module.self_attn, kwargs=module_kwargs,
         ))
         # attn out
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attn.v_proj,
-            layers=[module.self_attn.o_proj],
-            inp=input_feat['self_attn.o_proj'],
-        ))
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            scales_list.append(_auto_get_scale(
+                prev_op=module.self_attn.v_proj,
+                layers=[module.self_attn.o_proj],
+                inp=input_feat['self_attn.o_proj'],
+            ))
         # fc1
         scales_list.append(_auto_get_scale(
             prev_op=module.post_attention_layernorm,
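For context on the new shape check above: with grouped-query attention (GQA) the value projection has fewer output channels than the output projection has input channels, so the o_proj input scales can no longer be folded one-to-one into v_proj's output channels, and the v_proj -> o_proj scale pair is skipped. Below is a minimal sketch of the shape mismatch, using illustrative Llama-2-70B-style numbers (assumptions, not read from any config in this PR).

```python
# Illustrative numbers only (a Llama-2-70B-style GQA config is an assumption).
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8                          # GQA: fewer KV heads than query heads
head_dim = hidden_size // num_attention_heads    # 128

# nn.Linear weights are stored as (out_features, in_features):
v_proj_shape = (num_key_value_heads * head_dim, hidden_size)    # (1024, 8192)
o_proj_shape = (hidden_size, num_attention_heads * head_dim)    # (8192, 8192)

# With plain multi-head attention the two shapes match and the o_proj input
# scales can be folded into v_proj's output channels; under GQA each v_proj
# channel is shared by several query heads, so there is no 1:1 mapping and
# the pair is skipped.
print(v_proj_shape == o_proj_shape)  # False -> skip scaling v_proj -> o_proj
```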