Unverified Commit b0a0fecf authored by Jiaming Tang, committed by GitHub

Merge pull request #41 from mit-han-lab/dev/more_models

parents 25e92c4c ce4a6bb1
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
parser = argparse.ArgumentParser()
...@@ -20,6 +21,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "mode details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
...@@ -43,6 +50,9 @@ parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
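For reference, a minimal sketch (illustrative, not from this commit) of what the two parsing lines above produce for the --max_memory flag; the example values follow the argparse help string:

raw = ["0:10GiB", "1:10GiB", "cpu:30GiB"]        # e.g. --max_memory 0:10GiB 1:10GiB cpu:30GiB
pairs = [v.split(':') for v in raw]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in pairs}
print(max_memory)  # {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}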
if args.auto_parallel:
gpu_list = auto_parallel(args)
...@@ -62,39 +72,67 @@ def build_model_and_enc(model_path):
print(f"* Building model {model_path}")
# all hf model
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
if args.load_quant: # directly load quantized weights
# no need to really load the fp16 weights... just to get the model structure
print("Loading pre-computed quantized weights...") print("Loading pre-computed quantized weights...")
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(model_path, config=config, model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16) torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight( real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True) model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model = load_checkpoint_and_dispatch(
model, args.load_quant, device_map="balanced", model.tie_weights()
# TODO: can we remove this?
# Infer device map
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
n_samples=128, seqlen=512,
)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
exit(0)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
...@@ -113,12 +151,26 @@ def build_model_and_enc(model_path):
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
...@@ -136,11 +188,10 @@ def main():
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
if args.tasks is not None:
task_names = args.tasks.split(",")
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
...
...@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
...@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
...@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
...@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
...@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
...@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
...@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
...@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}
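For context, a minimal Python sketch (illustrative, not from this commit) of the tensor shapes implied by the comments above gemm_forward_cuda; the packed-zeros layout is an assumption, and no kernel is launched here:

import torch

M, IC, OC, G = 16, 4096, 4096, 128                     # G = group size, must be a multiple of 32
in_feats = torch.empty(M, IC, dtype=torch.float16, device="cuda")
kernel = torch.empty(IC, OC // 8, dtype=torch.int32, device="cuda")            # 8 packed 4-bit weights per int32
scaling_factors = torch.empty(IC // G, OC, dtype=torch.float16, device="cuda")
zeros = torch.empty(IC // G, OC // 8, dtype=torch.int32, device="cuda")        # assumed packed like the kernel
# group_size is recovered on the C++ side as num_in_channels / _scaling_factors.size(0)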
...@@ -3,7 +3,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtensio
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"],
}
setup(
...@@ -18,4 +18,4 @@ setup(
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
\ No newline at end of file
...@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
...@@ -73,11 +73,13 @@ def auto_clip_block(module,
clip_list = []
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
...@@ -86,8 +88,10 @@ def apply_clip(module, clip_list):
from ..utils.module import get_op_by_name
for name, max_val in clip_list:
layer = get_op_by_name(module, name)
layer.cuda()
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu()
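For reference, a minimal sketch (illustrative, not from this commit) of the grouped clipping that apply_clip performs, assuming max_val has shape [out_features, n_groups, 1] as produced by auto_clip_layer:

import torch

oc, ic, group_size = 8, 64, 32
w = torch.randn(oc, ic)
# per-(output channel, group) clipping threshold; the 0.9 factor is illustrative
max_val = w.abs().reshape(oc, ic // group_size, group_size).amax(dim=-1, keepdim=True) * 0.9
w_grouped = w.reshape(*max_val.shape[:2], -1)              # [oc, n_groups, group_size]
w_clipped = torch.clamp(w_grouped, -max_val, max_val).reshape(oc, ic)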
import gc
import torch
import torch.nn as nn
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
...@@ -32,6 +35,13 @@ def scale_ln_fcs(ln, fcs, scales):
scales = scales.to(ln.weight.device)
# debugging start even scales = 1 does not work?
"""
scales = scales * 0
scales = scales + 1
"""
# debugging end
ln.weight.div_(scales)
if hasattr(ln, 'bias') and ln.bias is not None:
ln.bias.div_(scales)
...@@ -50,11 +60,12 @@ def scale_ln_fcs(ln, fcs, scales):
def scale_fc_fc(fc1, fc2, scales):
assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear)
# assert fc1.out_features == fc2.in_features
scales = scales.to(fc1.weight.device)
# fc1.weight.div_(scales.view(-1, 1))
fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
...@@ -66,6 +77,17 @@ def scale_fc_fc(fc1, fc2, scales):
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
assert isinstance(fc, nn.Linear)
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def auto_scale_block(module, module_kwargs,
w_bit, q_config,
...@@ -86,11 +108,15 @@ def auto_scale_block(module, module_kwargs,
def _search_module_scale(block, linears2scale: list, x, kwargs={}):
# w: co, ci
# x: n, ci
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=q_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(block.parameters()).device)
with torch.no_grad():
org_out = block(x, **kwargs)
if isinstance(org_out, tuple):
...@@ -112,7 +138,7 @@ def auto_scale_block(module, module_kwargs,
).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
fc.weight.data = w_quantize_func(
fc.weight.data) / (scales.view(1, -1))
out = block(x, **kwargs)
...@@ -143,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
module2inspect = layers[0]
scales = _search_module_scale(module2inspect, layers, inp, kwargs)
scales = scales.detach().cpu()
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
...@@ -204,7 +231,110 @@ def auto_scale_block(module, module_kwargs,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
elif isinstance(module, BloomBlock):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
elif "mpt" in str(module.__class__).lower():
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj'],
))
elif "falcon" in str(module.__class__).lower():
# attn out
# Haotian: TBD: need to handle repeated scales for MQ
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1, as long as it is scaled, everything is screwed up
if "falcon-7b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
elif "falcon-40b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
scales_list.append(_auto_get_scale(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
else:
raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
...@@ -214,12 +344,21 @@ def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
prev_op.cuda()
for layer in layers:
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
else:
raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!")
...@@ -229,3 +368,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
for layer_name in layer_names:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
prev_op.cpu()
for layer in layers:
layer.cpu()
scales.cpu()
...@@ -5,6 +5,7 @@ import gc
import functools
from collections import defaultdict
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
...@@ -23,10 +24,32 @@ def get_blocks(model):
layers = model.model.layers
elif isinstance(model, OPTForCausalLM):
layers = model.model.decoder.layers
elif isinstance(model, BloomForCausalLM):
layers = model.transformer.h
elif "mpt" in str(model.__class__).lower():
layers = model.transformer.blocks
elif "falcon" in str(model.__class__).lower():
layers = model.transformer.h
else:
raise NotImplementedError(type(model))
return layers
def move_embed(model, device):
if isinstance(model, LlamaForCausalLM):
model.model.embed_tokens = model.model.embed_tokens.to(device)
elif isinstance(model, OPTForCausalLM):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
elif isinstance(model, BloomForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
elif "mpt" in str(model.__class__).lower():
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
elif "falcon" in str(model.__class__).lower():
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
else:
raise NotImplementedError(type(model))
@torch.no_grad()
def run_awq(
...@@ -50,6 +73,9 @@ def run_awq(
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
move_embed(model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
...@@ -69,9 +95,13 @@ def run_awq(
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
move_embed(model, "cpu")
gc.collect()
torch.cuda.empty_cache()
...@@ -83,6 +113,7 @@ def run_awq(
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
...@@ -102,19 +133,25 @@ def run_awq(
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
...@@ -124,6 +161,8 @@ def run_awq(
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
...
...@@ -4,6 +4,16 @@ import torch.nn as nn
import f16s4_gemm # with CUDA kernels
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
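For reference, a minimal sketch (illustrative, not from this commit) of how the ScaledActivation wrapper above behaves; the channel count and scale values are made up:

import torch
import torch.nn as nn

act = ScaledActivation(nn.GELU(), torch.full((4096,), 2.0))   # wrap an activation with per-channel scales
x = torch.randn(1, 8, 4096)
y = act(x)   # same as nn.GELU()(x) / 2.0, broadcast over the last dimension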
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
...@@ -83,3 +93,7 @@ class WQLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
...@@ -2,11 +2,48 @@ import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
param = next(module.parameters())
dtype = param.dtype
device = param.device
if isinstance(module, BloomBlock):
if isinstance(module.mlp.gelu_impl, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.gelu_impl,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.gelu_impl", act)
elif 'mptblock' in str(module.__class__.__name__).lower():
if isinstance(module.ffn.act, ScaledActivation):
return
c = module.ffn.up_proj.out_features
act = ScaledActivation(
module.ffn.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "ffn.act", act)
elif 'falcon' in str(module.__class__).lower():
if isinstance(module.mlp.act, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
...@@ -61,7 +98,9 @@ def pseudo_quantize_model_weight(
for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
named_linears = get_named_linears(layers[i])
for n, m in named_linears.items():
m.cuda()
m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
m.cpu()
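For reference, a minimal sketch (illustrative, not from this commit) of the fake-quantization call pattern above; the q_config keys shown are assumptions inferred from how q_config is used elsewhere in this diff:

import torch

q_config = {"zero_point": True, "q_group_size": 128}   # assumed keys, mirroring the **q_config usage above
w = torch.randn(4096, 4096)
w_fake_quant = pseudo_quantize_tensor(w, n_bit=4, **q_config)   # returns a tensor of the same shape with INT4-rounded values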
@torch.no_grad()
...@@ -77,29 +116,21 @@ def real_quantize_model_weight(
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
layer = layers[i]
named_linears = get_named_linears(layer)
scale_activations(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
\ No newline at end of file
...@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
return 2048
elif 'llama' in self.model_name:
return 2048 # TODO: did not check this
elif 'mpt' in self.model_name:
return 2048
elif 'falcon' in self.model_name:
return 2048
else:
print(self.model.config)
raise NotImplementedError
...
...@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
raise ValueError(f"Cannot find op {op_name} in module {module}")
def set_op_by_name(layer, name, new_module):
levels = name.split('.')
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels)-1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], new_module)
else:
setattr(layer, name, new_module)
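For reference, a minimal sketch (illustrative, not from this commit) of how set_op_by_name is used in this diff to swap a nested submodule in place; the module layout here is made up:

import torch.nn as nn

block = nn.ModuleDict({"mlp": nn.ModuleDict({"dense_4h_to_h": nn.Linear(16, 4)})})
replacement = nn.Linear(16, 4)
set_op_by_name(block, "mlp.dense_4h_to_h", replacement)   # walks "mlp", then replaces "dense_4h_to_h"
assert block["mlp"]["dense_4h_to_h"] is replacement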
def get_op_name(module, op):
# get the name of the op relative to the module
for name, m in module.named_modules():
...
import torch
import accelerate
def get_module_by_name_suffix(model, module_name: str):
for name, module in model.named_modules():
if name.endswith(module_name):
return module
def simple_dispatch_model(model, device_map):
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
if "" in device_map:
d = device_map[""]
model = model.to(torch.device(d))
model.hf_device_map = device_map
return model
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
prev_hook = None
for idx, (n, d) in enumerate(cpu_offload_group):
m = get_module_by_name_suffix(model, n)
_, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
if len(cpu_offload_group) > 1:
get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
for n, d in device_map.items():
m = get_module_by_name_suffix(model, n)
if d != "cpu":
d = torch.device(d)
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
add_hook_to_module(m, hook)
accelerate.utils.modeling.retie_parameters(model, tied_params)
model.hf_device_map = device_map
return model
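For reference, a minimal sketch (illustrative, not from this commit) of how simple_dispatch_model is driven from entry.py; the model name and memory budget are illustrative:

import torch
from accelerate import infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)
device_map = infer_auto_device_map(
    model,
    max_memory={0: "10GiB", "cpu": "30GiB"},
    no_split_module_classes=["OPTDecoderLayer"],
)
model = simple_dispatch_model(model, device_map=device_map)   # CPU-mapped modules get offload hooks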