Unverified commit 29ee66d9 authored by Casper, committed by GitHub

PEFT compatible GEMM (#324)

parent ebe8fc3f
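In short, this commit makes AWQ's 4-bit GEMM usable for PEFT/LoRA fine-tuning: the quantized matmul is wrapped in a torch.autograd.Function (WQLinearMMFunction) so that gradients can flow through WQLinear_GEMM layers, and an end-to-end LoRA training example is added. The sketch below shows how a caller might opt into the new gradient-capable path; WQLinear_GEMM and its training flag come from the diff, while the import path and checkpoint are assumptions that may vary between AutoAWQ versions.

```python
# Sketch only: enable the gradient-capable forward on every quantized linear
# before attaching LoRA adapters. The training flag is added by this commit;
# the import path and model checkpoint are assumptions, not part of the diff.
from awq import AutoAWQForCausalLM
from awq.modules.linear import WQLinear_GEMM  # assumed location of WQLinear_GEMM

model = AutoAWQForCausalLM.from_quantized("ybelkada/opt-125m-awq", fuse_layers=False)
for module in model.model.modules():
    if isinstance(module, WQLinear_GEMM):
        module.training = True  # forward() now dispatches to WQLinearMMFunction with grad enabled
```

The Trainer in the bundled example achieves the same thing implicitly, since putting the model in train mode sets training=True on every submodule.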
 import torch
 import torch.nn as nn
+from torch.autograd import Function
 from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm

@@ -10,9 +11,94 @@ try:
 except:
     AWQ_INSTALLED = False
+
+
+# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
+class WQLinearMMFunction(Function):
+    @staticmethod
+    # ctx is the first argument to forward
+    def forward(
+        ctx,
+        x,
+        qweight,
+        qzeros,
+        scales,
+        w_bit=4,
+        group_size=128,
+        bias=None,
+        out_features=0
+    ):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
+        ctx.out_features = out_features
+
+        out_shape = x.shape[:-1] + (out_features, )
+        x = x.to(torch.float16)
+
+        if AWQ_INSTALLED:
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
+
+            if FP16_MATMUL_HEURISTIC_CONDITION:
+                # Larger batches: dequantize the weights once, then a plain FP16 matmul.
+                out = awq_ext.dequantize_weights_cuda(
+                    qweight,
+                    scales,
+                    qzeros,
+                    0,
+                    0,
+                    0,
+                    False
+                )
+                out = torch.matmul(x, out)
+            else:
+                # Smaller batches: use the fused 4-bit GEMM kernel.
+                out = awq_ext.gemm_forward_cuda(
+                    x.reshape(-1, x.shape[-1]),
+                    qweight,
+                    scales,
+                    qzeros,
+                    8
+                )
+        else:
+            # No CUDA extension available: dequantize in PyTorch and matmul.
+            out = dequantize_gemm(
+                qweight,
+                qzeros,
+                scales,
+                w_bit,
+                group_size
+            )
+            out = torch.matmul(x, out)
+
+        out = out + bias if bias is not None else out
+        out = out.reshape(out_shape)
+
+        # always want 3D tensor if tensor is 2D
+        if len(out.shape) == 2:
+            out = out.unsqueeze(0)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, qweight, qzeros, scales, bias = ctx.saved_tensors
+
+        # Dequantize the frozen 4-bit weights to compute the gradient w.r.t. the input.
+        weights = awq_ext.dequantize_weights_cuda(
+            qweight,
+            scales,
+            qzeros,
+            1,
+            0,
+            0,
+            False
+        )
+
+        grad_input = None  # guard against UnboundLocalError when the input needs no grad
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).mm(
+                weights.transpose(0, 1)
+            ).unsqueeze(0)
+
+        return grad_input, None, None, None, None, None, None, None
+
+
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
         super().__init__()
 
         if w_bit not in [4]:
@@ -22,6 +108,7 @@ class WQLinear_GEMM(nn.Module):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+        self.training = training
 
         # quick sanity check (make sure alignment)
         assert self.in_features % self.group_size == 0
@@ -145,7 +232,6 @@ class WQLinear_GEMM(nn.Module):
         return awq_linear
 
-    @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
@@ -153,37 +239,29 @@ class WQLinear_GEMM(nn.Module):
         if input_dtype != torch.float16:
             x = x.half()
 
-        if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-            if FP16_MATMUL_HEURISTIC_CONDITION:
-                out = awq_ext.dequantize_weights_cuda(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    0,
-                    0,
-                    0,
-                    False,
-                )
-                out = torch.matmul(x, out)
-            else:
-                out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]),
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    8,
-                )
-        else:
-            out = dequantize_gemm(
-                self.qweight,
-                self.qzeros,
-                self.scales,
-                self.w_bit,
-                self.group_size,
-            )
-            out = torch.matmul(x, out)
+        if self.training:
+            # Training: run through WQLinearMMFunction so autograd records the op
+            # and gradients can flow back to the activations (and LoRA adapters).
+            out = WQLinearMMFunction.apply(
+                x,
+                self.qweight,
+                self.qzeros,
+                self.scales,
+                self.w_bit,
+                self.group_size,
+                self.bias,
+                self.out_features,
+            )
+        else:
+            # Inference: same kernels, but skip building the autograd graph.
+            with torch.no_grad():
+                out = WQLinearMMFunction.apply(
+                    x,
+                    self.qweight,
+                    self.qzeros,
+                    self.scales,
+                    self.w_bit,
+                    self.group_size,
+                    self.bias,
+                    self.out_features,
+                )
 
         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)
(remaining lines of this file unchanged)
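For intuition, here is a self-contained sketch (plain PyTorch, no AutoAWQ or CUDA extension needed) of the autograd.Function pattern used above: forward multiplies by a frozen weight (dequantized on the fly in the real code), and backward returns a gradient only for the activations, so the quantized base weight stays frozen while earlier layers and LoRA adapters still receive gradients. All names here are illustrative, not part of the commit.

```python
import torch
from torch.autograd import Function

class FrozenWeightMatmul(Function):
    """Illustrative stand-in for WQLinearMMFunction: matmul against a frozen weight."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. the activations only; the weight gets None (stays frozen).
            grad_input = grad_output @ weight.t()
        return grad_input, None

x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(8, 4)  # stands in for the dequantized 4-bit weight
FrozenWeightMatmul.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8]): gradients reach the input, not the weight
```

The commit also adds the following new example script, which fine-tunes a quantized model with LoRA through this path: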
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType
def prepare_split(tokenizer):
    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
    prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

    def format_prompt(x):
        return prompt_template.format(
            system="",
            prompt=x["instruction"],
            output=x["output"]
        )

    data = data.map(
        lambda x: {"text": format_prompt(x)},
    ).select_columns(["text"])
    data = data.map(lambda x: tokenizer(x["text"]), batched=True)

    return data
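# Note: after the two map() calls each record carries "text" plus the tokenizer's
# "input_ids" and "attention_mask"; DataCollatorForLanguageModeling(mlm=False)
# below derives the labels from input_ids for causal-LM training.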
model_path = "ybelkada/opt-125m-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
# Prepare data
data_train = prepare_split(tokenizer)
# Configure LoRA
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.5,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False
)
model = get_peft_model(model.model, lora_config)
model.print_trainable_parameters()
training_arguments = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    optim="adamw_torch",
    num_train_epochs=1,
    learning_rate=1e-4,
    # fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    save_steps=100,
    logging_steps=50,
    eval_steps=None,
    load_best_model_at_end=False
)
trainer = Trainer(
    model=model,
    train_dataset=data_train,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
trainer.save_model("output")
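After training, the LoRA adapter saved under output can be loaded back on top of the quantized base model for inference. A hedged sketch, assuming the standard peft PeftModel API and the same base checkpoint; paths, the prompt, and generation settings are illustrative:

```python
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from peft import PeftModel

model_path = "ybelkada/opt-125m-awq"
base = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Attach the adapter written by trainer.save_model("output").
model = PeftModel.from_pretrained(base.model, "output")
model.eval()

prompt = "<s>[INST]  Give three tips for staying healthy. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```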