Commit df0c600c authored by Abhinav Kulkarni

[Minor] Added model dispatch to GPU logic

parent 4e7ada89
```diff
 from lm_eval import evaluator, tasks
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import torch
 import argparse
 import os
 import json
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
```
```diff
@@ -23,7 +23,7 @@ parser.add_argument('--parallel', action='store_true',
 # max memory to offload larger models to CPU
 parser.add_argument('--max_memory', type=str, nargs='*',
                     help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
-                    + "Example: 0:10GiB 1:10GiB cpu:20GiB; " \
+                    + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
                     + "more details here: " \
                     + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
 parser.add_argument('--auto_parallel', action='store_true',
```
```diff
@@ -49,7 +49,7 @@ parser.add_argument('--load_awq', type=str, default=None,
                     help="load the awq search results")
 args = parser.parse_args()
-max_memory = [v.split(':') for v in (args.max_memory or "")]
+max_memory = [v.split(':') for v in (args.max_memory or [])]
 max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
 if args.auto_parallel:
```
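Read outside the diff, the two parsing lines above turn the space-separated `device:budget` pairs into the dictionary format accelerate's big-model utilities expect. A standalone sketch using the example values from the help string:

```python
# What `--max_memory 0:10GiB 1:10GiB cpu:30GiB` parses into.
raw = ["0:10GiB", "1:10GiB", "cpu:30GiB"]   # what argparse yields for nargs='*'
pairs = [v.split(':') for v in raw]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in pairs}
print(max_memory)  # {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}
# Integer keys are CUDA device ids; string keys such as 'cpu' are offload targets.
```

The `or []` change is mostly hygiene: iterating over the old `""` default also produced no pairs, but a list is the type argparse actually yields for `nargs='*'`.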
```diff
@@ -78,21 +78,28 @@ def build_model_and_enc(model_path):
     enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
     if args.load_quant: # directly load quantized weights
         # no need to really load the fp16 weights... just to get the model structure
+        print("Loading pre-computed quantized weights...")
         with init_empty_weights():
             model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                          torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)
+        # Passing empty max_memory={} causes error
+        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
         model = load_checkpoint_and_dispatch(
-            model, args.load_quant, device_map="balanced",
+            model,
+            checkpoint=args.load_quant,
+            device_map="balanced",
             # TODO: can we remove this?
             no_split_module_classes=[
-                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
+                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
+            **kwargs
         )
     else: # fp16 to quantized
         args.run_awq &= not args.load_awq # if load_awq, no need to run awq
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"
```
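Seen outside the diff, this branch builds the module tree without allocating real storage and then lets accelerate place the quantized checkpoint across devices. A minimal sketch of the same pattern, with hypothetical `model_path` and `quant_path` values standing in for the CLI arguments, and with the AWQ-specific `real_quantize_model_weight(..., init_only=True)` step omitted:

```python
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "facebook/opt-1.3b"            # illustrative, not from the commit
quant_path = "quant_cache/opt-1.3b-w4.pt"   # illustrative checkpoint path

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
with init_empty_weights():
    # Module structure only: parameters live on the "meta" device, no memory used.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True)

# Mirror the guard from the diff: only forward max_memory when it is non-empty,
# because an empty dict trips up accelerate's placement planner.
max_memory = {0: "10GiB", "cpu": "30GiB"}   # illustrative budgets
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=quant_path,
    device_map="balanced",
    no_split_module_classes=["OPTDecoderLayer"],  # keep each block on one device
    **kwargs,
)
```

Moving `load_checkpoint_and_dispatch` to one keyword argument per line also makes the later addition of `**kwargs` a one-line change.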
```diff
@@ -106,7 +113,7 @@ def build_model_and_enc(model_path):
             torch.nn.init.uniform_ = skip
             torch.nn.init.normal_ = skip
             model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
+                model_path, config=config, trust_remote_code=True, **kwargs)
             awq_results = run_awq(
                 model, enc,
```
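With the `kwargs` consolidation, every fp16 load in this function now also picks up `low_cpu_mem_usage=True`. A short illustration of the flag, using a hypothetical model name:

```python
import torch
from transformers import AutoModelForCausalLM

# Model name is illustrative. With low_cpu_mem_usage=True the checkpoint is
# loaded shard by shard into empty parameters, avoiding the transient second
# full-size copy in host RAM that the default loading path allocates.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
```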
```diff
@@ -124,7 +131,6 @@ def build_model_and_enc(model_path):
         else:
             # Inference with fake quant
             # Init model on CPU:
-            kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
             model = AutoModelForCausalLM.from_pretrained(
                 model_path, config=config, trust_remote_code=True, **kwargs)
```
```diff
@@ -157,14 +163,15 @@ def build_model_and_enc(model_path):
         raise NotImplementedError
 
     # Move the model to GPU (as much as possible) for LM evaluation
-    kwargs = {
-        "torch_dtype": torch.float16,
-        "device_map": "auto",
-    }
-    if len(max_memory):
-        kwargs["max_memory"] = max_memory
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, config=config, state_dict=model.state_dict(), trust_remote_code=True, **kwargs)
+    kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+    device_map = infer_auto_device_map(
+        model,
+        # TODO: can we remove this?
+        no_split_module_classes=[
+            "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
+        **kwargs
+    )
+    model = dispatch_model(model, device_map=device_map)
     return model, enc
```
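This is the dispatch-to-GPU logic the commit title refers to. Previously the fully built model was pushed through a second `from_pretrained` call with `state_dict=model.state_dict()` just to get device placement; now the tensors already in memory are partitioned and moved directly. A minimal sketch of the two-step pattern, assuming `model` already holds fp16 weights on CPU and using an illustrative memory budget:

```python
from accelerate import infer_auto_device_map, dispatch_model

# Step 1: plan the partition. accelerate walks the module tree and assigns
# each submodule to a device without exceeding the per-device budgets.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "10GiB", "cpu": "30GiB"},   # illustrative; may be omitted
    # A decoder block split across two devices would bounce activations
    # between them mid-forward; listing the block classes prevents that.
    no_split_module_classes=["LlamaDecoderLayer"],
)

# Step 2: execute the plan. Parameters are moved and forward hooks are
# installed so inputs follow the weights at inference time.
model = dispatch_model(model, device_map=device_map)
```

Besides skipping a redundant reload, this avoids serializing the whole state dict back through `from_pretrained`.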
```diff
@@ -133,5 +133,3 @@ def real_quantize_model_weight(
                 set_op_by_name(layer, name, q_linear)
                 torch.cuda.empty_cache()
                 gc.collect()
-
-    model.tie_weights()
\ No newline at end of file
```
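On the removed `model.tie_weights()`: in transformers this call re-links shared tensors (typically the input embeddings and `lm_head`) after a model is materialized, and it is a no-op when they are already tied. Whether the new dispatch path makes the call redundant is not visible in this diff. A small standalone illustration of what the call guarantees, with a hypothetical model choice:

```python
from transformers import AutoModelForCausalLM

# Any causal LM with config.tie_word_embeddings=True behaves the same;
# the model here is illustrative.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.tie_weights()  # a no-op here, since from_pretrained already ties them
assert model.get_input_embeddings().weight is model.get_output_embeddings().weight
```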