Commit 34f1faff authored by Jiaming Tang

[Major] Add CPU offloading support for run_awq

parent a293e16f
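
Note: the heart of this change is that run_awq no longer needs the whole model resident on the GPU. The model is built on the CPU and, during the search, only the embeddings and the decoder block currently being processed are moved to the GPU, then moved back. Below is a minimal, self-contained sketch of that per-layer offloading pattern; the dummy linear blocks and the no-op process() are placeholders for the real decoder layers and the per-layer AWQ scale/clip search, and a CUDA device is assumed.

import gc
import torch
import torch.nn as nn

# Sketch only: stand-ins for decoder layers; the whole "model" stays on CPU.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])

def process(layer):
    pass  # placeholder for the per-layer AWQ scale/clip search

for i in range(len(layers)):
    layer = layers[i].cuda()    # move only the current block to the GPU
    process(layer)
    layers[i] = layer.cpu()     # move it back; peak GPU memory stays ~one block
    gc.collect()
    torch.cuda.empty_cache()    # release cached GPU memory before the next block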
@@ -83,17 +83,22 @@ def build_model_and_enc(model_path):
             "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
         )
     else:  # fp16 to quantized
-        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path, config=config, trust_remote_code=True, **kwargs)
-        if args.load_awq:
-            print("Loading pre-computed AWQ results from", args.load_awq)
-            awq_results = torch.load(args.load_awq, map_location="cpu")
-            apply_awq(model, awq_results)
-        elif args.run_awq:
+        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+        if args.run_awq:
+            assert args.dump_awq, "Please save the awq results with --dump_awq"
+            # Init model on CPU
+            def skip(*args, **kwargs):
+                pass
+            torch.nn.init.kaiming_normal_ = skip
+            torch.nn.init.kaiming_uniform_ = skip
+            torch.nn.init.uniform_ = skip
+            torch.nn.init.normal_ = skip
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
+
             awq_results = run_awq(
                 model, enc,
                 w_bit=args.w_bit, q_config=q_config,
@@ -102,6 +107,19 @@ def build_model_and_enc(model_path):
             if args.dump_awq:
                 torch.save(awq_results, args.dump_awq)
                 print("AWQ results saved at", args.dump_awq)
+            exit(0)
+        else:
+            # Inference with fake quant
+            # Init model on GPUs:
+            kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, config=config, trust_remote_code=True, **kwargs)
+            if args.load_awq:
+                print("Loading pre-computed AWQ results from", args.load_awq)
+                awq_results = torch.load(args.load_awq, map_location="cpu")
+                apply_awq(model, awq_results)
 
     # weight quantization
     if args.w_bit is not None:
...
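
Note: because the run_awq path above now exits right after dumping, usage effectively splits into two stages. A hedged sketch of that workflow follows; the awq.quantize.pre_quant import path, the model name, the cache path, and the q_config keys are assumptions for illustration, while run_awq and apply_awq are called with the arguments visible in the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq.quantize.pre_quant import run_awq, apply_awq   # assumed module path

model_path = "facebook/opt-1.3b"                         # placeholder model
awq_cache = "awq_results.pt"                             # placeholder --dump_awq / --load_awq path
q_config = {"zero_point": True, "q_group_size": 128}     # assumed config keys

# Stage 1 (--run_awq --dump_awq): CPU-initialized fp16 model -> AWQ search -> save -> exit.
enc = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
awq_results = run_awq(model, enc, w_bit=4, q_config=q_config)
torch.save(awq_results, awq_cache)

# Stage 2 (--load_awq, a separate run): shard the model across GPUs, then apply the
# pre-computed scales/clips before weight quantization and evaluation.
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="balanced")
apply_awq(model, torch.load(awq_cache, map_location="cpu"))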
@@ -34,6 +34,22 @@ def get_blocks(model):
         raise NotImplementedError(type(model))
     return layers
 
+def move_embed(model, device):
+    if isinstance(model, LlamaForCausalLM):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+    elif isinstance(model, OPTForCausalLM):
+        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
+        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
+    elif isinstance(model, BloomForCausalLM):
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
+    elif "mpt" in str(model.__class__).lower():
+        model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+    elif "falcon" in str(model.__class__).lower():
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+    else:
+        raise NotImplementedError(type(model))
 
 @torch.no_grad()
 def run_awq(
@@ -57,6 +73,9 @@ def run_awq(
     inps = []
     layer_kwargs = {}
 
+    layers[0] = layers[0].cuda()
+    move_embed(model, "cuda")
+
     # get input and kwargs to layer 0
     # with_kwargs is only supported in PyTorch 2.0
     # use this Catcher hack for now
@@ -79,6 +98,9 @@ def run_awq(
     layers[0] = layers[0].module  # restore
     inps = inps[0]
 
+    layers[0] = layers[0].cpu()
+    move_embed(model, "cpu")
+
     gc.collect()
     torch.cuda.empty_cache()
@@ -90,6 +112,7 @@ def run_awq(
     # solve layer by layer
     for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
         layer = layers[i]
+        layer = layer.cuda()
         named_linears = get_named_linears(layer)
 
         # firstly, get input features of all linear layers
@@ -131,6 +154,7 @@ def run_awq(
         # append prefix to make names global
         awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
+        layer = layer.cpu()
 
         # Haotian: check activation replacement
         del input_feat
         gc.collect()
...
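
Note on the layers[0] / move_embed moves in the hunks above: the calibration forward pass only needs the embedding modules and the first decoder block on the GPU in order to record block-0 inputs via the "Catcher hack" mentioned in the context comments. Below is a self-contained toy version of that capture step; the Embedding/Linear modules and the simplified Catcher are illustrative stand-ins, and a CUDA device is assumed.

import torch
import torch.nn as nn

embed = nn.Embedding(100, 16)                                  # stand-in for the move_embed targets
blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # stand-ins for decoder blocks (on CPU)

inps = []

class Catcher(nn.Module):
    """Simplified stand-in for the Catcher hook used in run_awq."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, **kwargs):
        inps.append(x)       # record the input to block 0
        raise ValueError     # abort the rest of the forward pass early

embed = embed.cuda()                     # only the embeddings ...
blocks[0] = Catcher(blocks[0]).cuda()    # ... and block 0 go to the GPU
try:
    blocks[0](embed(torch.randint(0, 100, (1, 8)).cuda()))
except ValueError:
    pass
blocks[0] = blocks[0].module.cpu()       # restore the wrapped block and offload it again
embed = embed.cpu()
torch.cuda.empty_cache()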