"csrc/vscode:/vscode.git/clone" did not exist on "ffb1f7bf095bcf4b8e25383820053e2c8e92f5ee"
Commit 34f1faff authored by Jiaming Tang

[Major] Add CPU offloading support for run_awq

parent a293e16f
@@ -83,17 +83,22 @@ def build_model_and_enc(model_path):
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        )
    else: # fp16 to quantized
        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
        args.run_awq &= not args.load_awq # if load_awq, no need to run awq
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config, trust_remote_code=True, **kwargs)
        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"
        if args.load_awq:
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
            apply_awq(model, awq_results)
            # Init model on CPU
            def skip(*args, **kwargs):
                pass
            torch.nn.init.kaiming_normal_ = skip
            torch.nn.init.kaiming_uniform_ = skip
            torch.nn.init.uniform_ = skip
            torch.nn.init.normal_ = skip
            model = AutoModelForCausalLM.from_pretrained(
                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
        elif args.run_awq:
            awq_results = run_awq(
                model, enc,
                w_bit=args.w_bit, q_config=q_config,
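The run_awq path now builds the fp16 model on the CPU and replaces torch.nn.init's initializers with a no-op, so from_pretrained does not spend time randomly initializing weights that the checkpoint immediately overwrites. A minimal stand-alone sketch of the same trick (the model id below is illustrative, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM

def _skip_init(*args, **kwargs):
    # Pretrained weights overwrite these tensors right away, so random
    # initialization on CPU is wasted work; make it a no-op.
    pass

# Monkey-patch the default initializers before building the model,
# mirroring the four overrides added in the hunk above.
torch.nn.init.kaiming_normal_ = _skip_init
torch.nn.init.kaiming_uniform_ = _skip_init
torch.nn.init.uniform_ = _skip_init
torch.nn.init.normal_ = _skip_init

# Illustrative model id; the script uses the user-supplied model_path.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16)  # stays on CPU, no device_map
```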
@@ -103,6 +108,19 @@ def build_model_and_enc(model_path):
            torch.save(awq_results, args.dump_awq)
            print("AWQ results saved at", args.dump_awq)
            exit(0)
        else:
            # Inference with fake quant
            # Init model on GPUs:
            kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
            model = AutoModelForCausalLM.from_pretrained(
                model_path, config=config, trust_remote_code=True, **kwargs)
            if args.load_awq:
                print("Loading pre-computed AWQ results from", args.load_awq)
                awq_results = torch.load(args.load_awq, map_location="cpu")
                apply_awq(model, awq_results)
        # weight quantization
        if args.w_bit is not None:
            if args.q_backend == "fake":
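The new else branch keeps the previous behavior for fake-quant inference: shard the fp16 model across the visible GPUs with device_map="balanced" and, when --load_awq is given, load the pre-computed AWQ results onto the CPU and apply them in place. A rough sketch of that flow, assuming the repo's apply_awq helper and an illustrative results path:

```python
import torch
from transformers import AutoModelForCausalLM
from awq.quantize.pre_quant import apply_awq  # helper from this repo; import path assumed

# Illustrative values; the script reads these from command-line arguments.
model_path = "facebook/opt-125m"
load_awq = "awq_cache/opt-125m-w4-g128.pt"

# Shard the fp16 weights evenly across all visible GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="balanced")

# AWQ results (scales and clip values) are small; load them on CPU
# and fold them into the sharded model in place.
awq_results = torch.load(load_awq, map_location="cpu")
apply_awq(model, awq_results)
```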
@@ -34,6 +34,22 @@ def get_blocks(model):
        raise NotImplementedError(type(model))
    return layers


def move_embed(model, device):
    if isinstance(model, LlamaForCausalLM):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    elif isinstance(model, OPTForCausalLM):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
    elif isinstance(model, BloomForCausalLM):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
    elif "mpt" in str(model.__class__).lower():
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
    elif "falcon" in str(model.__class__).lower():
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
    else:
        raise NotImplementedError(type(model))


@torch.no_grad()
def run_awq(
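The new move_embed helper records, per architecture, where the embedding modules live so that run_awq can shuttle just the embeddings (together with the first block) between CPU and GPU. A hedged usage sketch; the model id and the import path are assumptions for illustration:

```python
import torch
from transformers import AutoModelForCausalLM
from awq.quantize.pre_quant import move_embed  # helper added in this commit; path assumed

# Illustrative model; any of the architectures handled above would do.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16)

move_embed(model, "cuda")   # embeddings join the first block on the GPU
# ... run the calibration forward pass that records inputs to layer 0 ...
move_embed(model, "cpu")    # return them to CPU once the inputs are cached
torch.cuda.empty_cache()
```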
@@ -57,6 +73,9 @@ def run_awq(
    inps = []
    layer_kwargs = {}

    layers[0] = layers[0].cuda()
    move_embed(model, "cuda")

    # get input and kwargs to layer 0
    # with_kwargs is only supported in PyTorch 2.0
    # use this Catcher hack for now
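Only layer 0 and the embeddings have to sit on the GPU to capture the hidden states and keyword arguments the model feeds into the first block; the "Catcher hack" named in the comment is the common pattern of wrapping that block in a module that records its inputs and then aborts the forward pass. A generic sketch of the pattern, not the repo's exact implementation:

```python
import torch.nn as nn

class Catcher(nn.Module):
    """Wrap the first decoder block, record what it receives, then bail out."""

    def __init__(self, module, inps, layer_kwargs):
        super().__init__()
        self.module = module
        self.inps = inps
        self.layer_kwargs = layer_kwargs

    def forward(self, hidden_states, **kwargs):
        self.inps.append(hidden_states)   # hidden states entering layer 0
        self.layer_kwargs.update(kwargs)  # e.g. attention mask, position ids
        raise ValueError                  # cheap early exit; the caller catches it

# How it is typically driven (names follow the surrounding diff):
#   layers[0] = Catcher(layers[0], inps, layer_kwargs)
#   try:
#       model(samples.to("cuda"))        # forward stops right after layer 0's input
#   except ValueError:
#       pass
#   layers[0] = layers[0].module         # restore, as in the next hunk
```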
@@ -79,6 +98,9 @@ def run_awq(
    layers[0] = layers[0].module  # restore
    inps = inps[0]

    layers[0] = layers[0].cpu()
    move_embed(model, "cpu")

    gc.collect()
    torch.cuda.empty_cache()
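With the calibration inputs cached, layer 0 and the embeddings return to the CPU and the CUDA caching allocator is flushed, so the per-layer loop starts from an essentially empty GPU. A small illustrative helper for the same move-back-and-free pattern (the commit inlines these calls rather than defining a helper):

```python
import gc
import torch
import torch.nn as nn

def release_to_cpu(*modules: nn.Module) -> None:
    """Move modules back to CPU and return cached GPU memory to the allocator."""
    for m in modules:
        m.cpu()                  # in-place for nn.Module; parameters leave the GPU
    gc.collect()                 # drop any lingering Python references first
    torch.cuda.empty_cache()     # then release the cached CUDA blocks
```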
@@ -90,6 +112,7 @@ def run_awq(
    # solve layer by layer
    for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        layer = layers[i]
        layer = layer.cuda()
        named_linears = get_named_linears(layer)

        # firstly, get input features of all linear layers
@@ -131,6 +154,7 @@ def run_awq(
        # append prefix to make names global
        awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

        layer = layer.cpu()
        # Haotian: check activation replacement
        del input_feat
        gc.collect()
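The per-layer loop is where the offloading pays off: each decoder block is moved to the GPU only for its own AWQ search and sent back to the CPU before the next one, so peak GPU memory stays around one block plus activations instead of the whole model. A condensed, hedged sketch of that loop shape; process_layer is a placeholder for the per-block work run_awq actually does:

```python
import gc
import torch
import tqdm

def offloaded_layer_loop(layers, process_layer):
    """Process one GPU-resident decoder block at a time.

    `process_layer` stands in for the per-block work in run_awq
    (collecting input features, searching scales, searching clipping ranges).
    """
    for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        layer = layers[i].cuda()   # only this block occupies GPU memory
        process_layer(layer)
        layers[i] = layer.cpu()    # hand the block back to CPU
        gc.collect()
        torch.cuda.empty_cache()
```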