Unverified Commit b0a0fecf authored by Jiaming Tang's avatar Jiaming Tang Committed by GitHub
Browse files

Merge pull request #41 from mit-han-lab/dev/more_models

parents 25e92c4c ce4a6bb1
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
parser = argparse.ArgumentParser()
......@@ -20,6 +21,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "mode details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
......@@ -43,6 +50,9 @@ parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
if args.auto_parallel:
gpu_list = auto_parallel(args)
......@@ -62,39 +72,67 @@ def build_model_and_enc(model_path):
print(f"* Building model {model_path}")
# all hf model
config = AutoConfig.from_pretrained(model_path)
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
if args.load_quant: # directly load quantized weights
# no need to really load the fp16 weights... just to get the model structure
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
torch_dtype=torch.float16)
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model = load_checkpoint_and_dispatch(
model, args.load_quant, device_map="balanced",
# TODO: can we remove this?
model.tie_weights()
# Infer device map
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer"]
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
else: # fp16 to quantized
kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, **kwargs)
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
n_samples=128, seqlen=512,
)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
exit(0)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
......@@ -113,12 +151,26 @@ def build_model_and_enc(model_path):
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
......@@ -136,11 +188,10 @@ def main():
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
if args.tasks is not None:
task_names = args.tasks.split(",")
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
......
......@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
......@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
......@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
......@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
......@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
......@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
......@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
......@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 128 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 128");
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}
......@@ -3,7 +3,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtensio
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17", "-keep"],
"nvcc": ["-O3", "-std=c++17"],
}
setup(
......@@ -18,4 +18,4 @@ setup(
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
\ No newline at end of file
)
......@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 # prevent OOM
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
......@@ -73,11 +73,13 @@ def auto_clip_block(module,
clip_list = []
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in ["q_", "k_"]]):
if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
......@@ -86,8 +88,10 @@ def apply_clip(module, clip_list):
from ..utils.module import get_op_by_name
for name, max_val in clip_list:
layer = get_op_by_name(module, name)
layer.cuda()
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu()
import gc
import torch
import torch.nn as nn
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from ..utils.module import get_op_by_name, get_op_name
from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
......@@ -32,6 +35,13 @@ def scale_ln_fcs(ln, fcs, scales):
scales = scales.to(ln.weight.device)
# debugging start even scales = 1 does not work?
"""
scales = scales * 0
scales = scales + 1
"""
# debugging end
ln.weight.div_(scales)
if hasattr(ln, 'bias') and ln.bias is not None:
ln.bias.div_(scales)
......@@ -50,11 +60,12 @@ def scale_ln_fcs(ln, fcs, scales):
def scale_fc_fc(fc1, fc2, scales):
assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear)
assert fc1.out_features == fc2.in_features
# assert fc1.out_features == fc2.in_features
scales = scales.to(fc1.weight.device)
fc1.weight.div_(scales.view(-1, 1))
# fc1.weight.div_(scales.view(-1, 1))
fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
......@@ -66,6 +77,17 @@ def scale_fc_fc(fc1, fc2, scales):
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
assert isinstance(fc, nn.Linear)
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def auto_scale_block(module, module_kwargs,
w_bit, q_config,
......@@ -86,11 +108,15 @@ def auto_scale_block(module, module_kwargs,
def _search_module_scale(block, linears2scale: list, x, kwargs={}):
# w: co, ci
# x: n, ci
x = x.to(next(block.parameters()).device)
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=q_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(block.parameters()).device)
with torch.no_grad():
org_out = block(x, **kwargs)
if isinstance(org_out, tuple):
......@@ -112,7 +138,7 @@ def auto_scale_block(module, module_kwargs,
).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1))
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
fc.weight.data = w_quantize_func(
fc.weight.data) / (scales.view(1, -1))
out = block(x, **kwargs)
......@@ -143,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
module2inspect = layers[0]
scales = _search_module_scale(module2inspect, layers, inp, kwargs)
scales = scales.detach().cpu()
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
......@@ -204,7 +231,110 @@ def auto_scale_block(module, module_kwargs,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
elif isinstance(module, BloomBlock):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
elif "mpt" in str(module.__class__).lower():
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj'],
))
elif "falcon" in str(module.__class__).lower():
# attn out
# Haotian: TBD: need to handle repeated scales for MQ
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1, as long as it is scaled, everything is screwed up
if "falcon-7b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
elif "falcon-40b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
scales_list.append(_auto_get_scale(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
else:
raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
......@@ -214,12 +344,21 @@ def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
prev_op.cuda()
for layer in layers:
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
else:
raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!")
......@@ -229,3 +368,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
for layer_name in layer_names:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
prev_op.cpu()
for layer in layers:
layer.cpu()
scales.cpu()
......@@ -5,6 +5,7 @@ import gc
import functools
from collections import defaultdict
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
......@@ -23,10 +24,32 @@ def get_blocks(model):
layers = model.model.layers
elif isinstance(model, OPTForCausalLM):
layers = model.model.decoder.layers
elif isinstance(model, BloomForCausalLM):
layers = model.transformer.h
elif "mpt" in str(model.__class__).lower():
layers = model.transformer.blocks
elif "falcon" in str(model.__class__).lower():
layers = model.transformer.h
else:
raise NotImplementedError(type(model))
return layers
def move_embed(model, device):
if isinstance(model, LlamaForCausalLM):
model.model.embed_tokens = model.model.embed_tokens.to(device)
elif isinstance(model, OPTForCausalLM):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
elif isinstance(model, BloomForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
elif "mpt" in str(model.__class__).lower():
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
elif "falcon" in str(model.__class__).lower():
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
else:
raise NotImplementedError(type(model))
@torch.no_grad()
def run_awq(
......@@ -50,6 +73,9 @@ def run_awq(
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
move_embed(model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
......@@ -69,9 +95,13 @@ def run_awq(
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
move_embed(model, "cpu")
gc.collect()
torch.cuda.empty_cache()
......@@ -83,6 +113,7 @@ def run_awq(
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
......@@ -102,19 +133,25 @@ def run_awq(
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
apply_scale(layer, scales_list, input_feat_dict=input_feat)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
......@@ -124,6 +161,8 @@ def run_awq(
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
......
......@@ -4,6 +4,16 @@ import torch.nn as nn
import f16s4_gemm # with CUDA kernels
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
......@@ -83,3 +93,7 @@ class WQLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
......@@ -2,11 +2,48 @@ import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
param = next(module.parameters())
dtype = param.dtype
device = param.device
if isinstance(module, BloomBlock):
if isinstance(module.mlp.gelu_impl, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.gelu_impl,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.gelu_impl", act)
elif 'mptblock' in str(module.__class__.__name__).lower():
if isinstance(module.ffn.act, ScaledActivation):
return
c = module.ffn.up_proj.out_features
act = ScaledActivation(
module.ffn.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "ffn.act", act)
elif 'falcon' in str(module.__class__).lower():
if isinstance(module.mlp.act, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
......@@ -61,7 +98,9 @@ def pseudo_quantize_model_weight(
for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
named_linears = get_named_linears(layers[i])
for n, m in named_linears.items():
m.cuda()
m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
m.cpu()
@torch.no_grad()
......@@ -77,29 +116,21 @@ def real_quantize_model_weight(
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
layer = layers[i]
named_linears = get_named_linears(layer)
scale_activations(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
levels = name.split('.')
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels)-1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], q_linear)
else:
setattr(layer, name, q_linear)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
\ No newline at end of file
gc.collect()
......@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
return 2048
elif 'llama' in self.model_name:
return 2048 # TODO: did not check this
elif 'mpt' in self.model_name:
return 2048
elif 'falcon' in self.model_name:
return 2048
else:
print(self.model.config)
raise NotImplementedError
......
......@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
raise ValueError(f"Cannot find op {op_name} in module {module}")
def set_op_by_name(layer, name, new_module):
levels = name.split('.')
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels)-1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], new_module)
else:
setattr(layer, name, new_module)
def get_op_name(module, op):
# get the name of the op relative to the module
for name, m in module.named_modules():
......
import torch
import accelerate
def get_module_by_name_suffix(model, module_name: str):
for name, module in model.named_modules():
if name.endswith(module_name):
return module
def simple_dispatch_model(model, device_map):
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
if "" in device_map:
d = device_map[""]
model = model.to(torch.device(d))
model.hf_device_map = device_map
return model
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
prev_hook = None
for idx, (n, d) in enumerate(cpu_offload_group):
m = get_module_by_name_suffix(model, n)
_, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
if len(cpu_offload_group) > 1:
get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
for n, d in device_map.items():
m = get_module_by_name_suffix(model, n)
if d != "cpu":
d = torch.device(d)
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
add_hook_to_module(m, hook)
accelerate.utils.modeling.retie_parameters(model, tied_params)
model.hf_device_map = device_map
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment