Commit df0c600c authored by Abhinav Kulkarni

[Minor] Added model dispatch to GPU logic

parent 4e7ada89
 from lm_eval import evaluator, tasks
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import torch
 import argparse
 import os
 import json
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
@@ -23,7 +23,7 @@ parser.add_argument('--parallel', action='store_true',
 # max memory to offload larger models to CPU
 parser.add_argument('--max_memory', type=str, nargs='*',
                     help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
-                         + "Example: 0:10GiB 1:10GiB cpu:20GiB; " \
+                         + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
                          + "more details here: " \
                          + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
 parser.add_argument('--auto_parallel', action='store_true',
@@ -49,7 +49,7 @@ parser.add_argument('--load_awq', type=str, default=None,
                     help="load the awq search results")
 args = parser.parse_args()

-max_memory = [v.split(':') for v in (args.max_memory or "")]
+max_memory = [v.split(':') for v in (args.max_memory or [])]
 max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}

 if args.auto_parallel:
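For reference, the two comprehensions above turn the documented example into the dictionary format Accelerate expects. A minimal sketch (variable names are illustrative; the values are the ones from the help text):

# Equivalent of passing: --max_memory 0:10GiB 1:10GiB cpu:30GiB
raw = ["0:10GiB", "1:10GiB", "cpu:30GiB"]
pairs = [v.split(':') for v in (raw or [])]
# Digit-only keys become GPU indices; "cpu" stays a string key
max_memory = {(int(k) if k.isdigit() else k): v for k, v in pairs}
assert max_memory == {0: "10GiB", 1: "10GiB", "cpu": "30GiB"}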
@@ -78,22 +78,29 @@ def build_model_and_enc(model_path):
     enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)

     if args.load_quant:  # directly load quantized weights
+        # no need to really load the fp16 weights... just to get the model structure
         print("Loading pre-computed quantized weights...")
         with init_empty_weights():
             model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                          torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)

+        # Passing empty max_memory={} causes error
+        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
         model = load_checkpoint_and_dispatch(
-            model, args.load_quant, device_map="balanced",
+            model,
+            checkpoint=args.load_quant,
+            device_map="balanced",
             # TODO: can we remove this?
             no_split_module_classes=[
-                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
+                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
+            **kwargs
         )
     else:  # fp16 to quantized
         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"
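The load_quant branch above follows Accelerate's big-model loading pattern: build only the module structure on the meta device, swap in empty quantized layers (init_only=True), then let load_checkpoint_and_dispatch stream the real tensors in while sharding them across devices. A standalone sketch of the same pattern, assuming an OPT-style model; the model name and checkpoint path are placeholders, not taken from this repo:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-1.3b")
with init_empty_weights():
    # Builds the module tree on the meta device; no weight memory is allocated
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/path/to/quantized-ckpt.pt",      # placeholder for --load_quant
    device_map="balanced",                        # spread layers evenly across GPUs
    no_split_module_classes=["OPTDecoderLayer"],  # keep each decoder block whole
)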
@@ -106,7 +113,7 @@ def build_model_and_enc(model_path):
             torch.nn.init.uniform_ = skip
             torch.nn.init.normal_ = skip
             model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
+                model_path, config=config, trust_remote_code=True, **kwargs)

             awq_results = run_awq(
                 model, enc,
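The context lines above show the usual skip-init trick: since from_pretrained immediately overwrites every parameter with checkpoint values, the default random initializers are monkey-patched to no-ops to cut startup time. A sketch of the idiom; the kaiming_uniform_ line is often patched alongside the two initializers visible in this hunk, but it is an assumption here:

import torch

def skip(*args, **kwargs):
    pass  # no-op: checkpoint values overwrite these tensors anyway

torch.nn.init.kaiming_uniform_ = skip  # assumed; not visible in this hunk
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip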
@@ -124,7 +131,6 @@ def build_model_and_enc(model_path):
         else:
             # Inference with fake quant
             # Init model on CPU:
-            kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
             model = AutoModelForCausalLM.from_pretrained(
                 model_path, config=config, trust_remote_code=True, **kwargs)

@@ -156,15 +162,16 @@ def build_model_and_enc(model_path):
         else:
             raise NotImplementedError

     # Move the model to GPU (as much as possible) for LM evaluation
-    kwargs = {
-        "torch_dtype": torch.float16,
-        "device_map": "auto",
-    }
-    if len(max_memory):
-        kwargs["max_memory"] = max_memory
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, config=config, state_dict=model.state_dict(), trust_remote_code=True, **kwargs)
+    kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+    device_map = infer_auto_device_map(
+        model,
+        # TODO: can we remove this?
+        no_split_module_classes=[
+            "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
+        **kwargs
+    )
+    model = dispatch_model(model, device_map=device_map)

     return model, enc
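This hunk is the change the commit title describes: instead of reloading the weights through a second from_pretrained(..., device_map="auto") call with state_dict=model.state_dict(), the already-built model is sharded in place in two steps: plan a placement, then move submodules. A self-contained sketch with a toy module standing in for the real model (assumes a CUDA device 0; the memory caps are illustrative):

import torch
from torch import nn
from accelerate import infer_auto_device_map, dispatch_model

# Toy stand-in for the CPU-resident fp16/quantized model built above
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)]).half()

# Step 1: compute a placement under optional per-device caps
# (the same {device: "NGiB"} dict that --max_memory produces)
device_map = infer_auto_device_map(model, max_memory={0: "10MiB", "cpu": "1GiB"})

# Step 2: move each submodule to its slot; Accelerate attaches hooks that
# route activations between devices during forward passes
model = dispatch_model(model, device_map=device_map)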
awq/quantize/quantizer.py
@@ -133,5 +133,3 @@ def real_quantize_model_weight(
             set_op_by_name(layer, name, q_linear)
         torch.cuda.empty_cache()
         gc.collect()
-
-    model.tie_weights()
\ No newline at end of file
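With tie_weights() removed, real_quantize_model_weight no longer re-ties shared tensors itself, and the diff does not show where (or whether) callers now do this. For architectures whose input embeddings share storage with the LM head, transformers exposes the same call on the model object; a hedged sketch of where it would fit (the model name is a placeholder):

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-125m")  # placeholder model
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# ... structural surgery here, e.g. swapping nn.Linear for quantized layers ...

# Replacing modules can break tied parameters, so re-tie before loading or
# dispatching real weights; this mirrors the call removed above.
model.tie_weights()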