"csrc/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "014c4baefa0ba920799e58ceaecdc0f22c0e006e"
Commit 71d8e68d authored by Haotian Tang

[Major] Add support for BLOOM, MPT and Falcon.

parent 06e299ba
@@ -62,15 +62,18 @@ def build_model_and_enc(model_path):
    print(f"* Building model {model_path}")
    # all hf model
-    config = AutoConfig.from_pretrained(model_path)
-    enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
+    else:
+        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    if args.load_quant:  # directly load quantized weights
        # no need to really load the fp16 weights... just to get the model structure
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
-                                                         torch_dtype=torch.float16)
+                                                         torch_dtype=torch.float16, trust_remote_code=True)
        real_quantize_model_weight(
            model, w_bit=args.w_bit, q_config=q_config, init_only=True)
        model = load_checkpoint_and_dispatch(
@@ -83,8 +86,7 @@ def build_model_and_enc(model_path):
        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
        model = AutoModelForCausalLM.from_pretrained(
-            model_path, config=config, **kwargs)
+            model_path, config=config, trust_remote_code=True, **kwargs)
    if args.run_awq:
        awq_results = run_awq(
            model, enc,
...
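For reference, a minimal consolidated sketch of the loading path introduced above (the checkpoint name is illustrative, not part of this commit): MPT and Falcon ship their modeling code with the checkpoint, so both the config and the model need trust_remote_code=True, and MPT's tokenizer is resolved through config.tokenizer_name.

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch

model_path = "mosaicml/mpt-7b"  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
    enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
else:
    enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True)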
@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
return (v1 << 16) | v0;
}
-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
  + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
  + (((int)blockIdx_y) % j_factors1) * (128 / 8)
  + (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
  + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
+if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
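// The last split-k step of this block reads at k_0_0 = (k_bound - 1) * split_k_iters + blockIdx_z,
// i.e. input-channel offset k_0_0 * 32; the corrected guard drops that step when it would run past IC.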
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
+uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
+uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
}
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 72 (64+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+    int group_size = num_in_channels / _scaling_factors.size(0);
-    if (num_out_channels % 128 != 0)
-        throw std::invalid_argument("OC is not multiple of cta_N = 128");
+    if (num_out_channels % 64 != 0)
+        throw std::invalid_argument("OC is not multiple of cta_N = 64");
    if (num_out_channels % 8 != 0)
        throw std::invalid_argument("OC is not multiple of pack_num = 8");
-    int j_factors1 = num_out_channels / 128 / 1;
-    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-    // threadIdx.x: 32
-    // threadIdx.y: i_factors[2] * j_factors[2]
-    dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
-        split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    if (group_size % 32 != 0)
+        throw std::invalid_argument("Group size should be a multiple of 32");
+    if (num_out_channels % group_size != 0)
+        throw std::invalid_argument("OC is not multiple of Group size");
+    if (num_out_channels % 128 == 0)
+    {
+        int j_factors1 = num_out_channels / 128 / 1;
+        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    else if (num_out_channels % 64 == 0)
+    {
+        int j_factors1 = num_out_channels / 64 / 1;
+        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
    return _out_feats.sum(0);
}
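A hedged host-side sketch of the contract spelled out by the shape comments above; the Python-visible argument order of f16s4_gemm.gemm_forward_cuda is assumed, not shown in this hunk, and the tensor contents are random placeholders.

import torch
import f16s4_gemm  # CUDA extension built from this file

M, IC, OC, G = 16, 4096, 4096, 128   # OC % 128 == 0 -> m16n128k32; an OC that is a multiple of 64 but not 128 takes the new m16n64k32 kernel
split_k_iters = 8

in_feats        = torch.randn(M, IC, dtype=torch.float16, device="cuda")                            # M x IC, fp16
kernel          = torch.randint(0, 2**31 - 1, (IC, OC // 8), dtype=torch.int32, device="cuda")      # packed 4-bit weights
scaling_factors = torch.rand(IC // G, OC, dtype=torch.float16, device="cuda")                       # IC//G x OC, fp16
zeros           = torch.randint(0, 2**31 - 1, (IC // G, OC // 8), dtype=torch.int32, device="cuda") # packed zero points

# group_size is recovered inside the launcher as IC / scaling_factors.size(0) == G
out = f16s4_gemm.gemm_forward_cuda(in_feats, kernel, scaling_factors, zeros, split_k_iters)
print(out.shape)  # expected: (M, OC) fp16, after summing over the split-k dimension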
@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)
-    oc_batch_size = 256  # prevent OOM
+    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []
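The fallback above keeps the divisibility assert satisfied for models whose output-channel count is a multiple of 64 but not 256. A more general helper (hypothetical, not part of this commit) would be:

def pick_oc_batch_size(out_channels, candidates=(256, 64)):
    # choose the largest output-channel chunk that evenly divides the weight,
    # so the per-chunk clipping search never sees a ragged last batch
    for c in candidates:
        if out_channels % c == 0:
            return c
    raise ValueError(f"no candidate OC batch size divides {out_channels}")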
...
import torch
import torch.nn as nn
+from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
-from ..utils.module import get_op_by_name, get_op_name
+from .qmodule import ScaledActivation
+from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
@@ -32,6 +34,13 @@ def scale_ln_fcs(ln, fcs, scales):
    scales = scales.to(ln.weight.device)
    # debugging start: even scales = 1 does not work?
    """
    scales = scales * 0
    scales = scales + 1
    """
    # debugging end
    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)
@@ -50,11 +59,12 @@ def scale_ln_fcs(ln, fcs, scales):
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
-    assert fc1.out_features == fc2.in_features
+    # assert fc1.out_features == fc2.in_features
    scales = scales.to(fc1.weight.device)
-    fc1.weight.div_(scales.view(-1, 1))
+    # fc1.weight.div_(scales.view(-1, 1))
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))
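A hedged reading of the slicing above: when the producer is a fused projection (e.g. MPT's attn.Wqkv feeding attn.out_proj), fc1 has more output rows than fc2 has inputs, and the computed scales only correspond to the final scales.size(0) rows (the value slice), so only that slice is divided; the matching multiplication on fc2 (not shown in this hunk) presumably cancels it. Shapes only, names hypothetical:

import torch

d = 8
w_qkv = torch.randn(3 * d, d)          # fused Q/K/V rows stacked along dim 0
scales = torch.rand(d) + 0.5           # one scale per V output channel
w_qkv[-scales.size(0):] /= scales.view(-1, 1)   # touches only the last d rows (the V projection)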
@@ -66,6 +76,17 @@ def scale_fc_fc(fc1, fc2, scales):
        assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
    assert isinstance(fc, nn.Linear)
    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
@torch.no_grad()
def auto_scale_block(module, module_kwargs,
                     w_bit, q_config,
@@ -112,7 +133,7 @@ def auto_scale_block(module, module_kwargs,
            ).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
-                fc.weight.mul_(scales.view(1, -1))
+                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = w_quantize_func(
                    fc.weight.data) / (scales.view(1, -1))
            out = block(x, **kwargs)
@@ -204,7 +225,91 @@ def auto_scale_block(module, module_kwargs,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))
elif isinstance(module, BloomBlock):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attn out
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attn.dense'],
))
"""
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
elif "mpt" in str(module.__class__).lower():
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj'],
))
elif "falcon" in str(module.__class__).lower():
# attn out
# Haotian: TBD: need to handle repeated scales for MQ
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1, as long as it is scaled, everything is screwed up
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
    else:
        raise NotImplementedError(f"{type(module)} not supported yet!")
@@ -220,6 +325,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")
@@ -228,4 +337,4 @@ def apply_scale(module, scales_list, input_feat_dict=None):
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))
\ No newline at end of file
@@ -5,6 +5,7 @@ import gc
import functools
from collections import defaultdict
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
@@ -23,6 +24,12 @@ def get_blocks(model):
        layers = model.model.layers
    elif isinstance(model, OPTForCausalLM):
        layers = model.model.decoder.layers
    elif isinstance(model, BloomForCausalLM):
        layers = model.transformer.h
    elif "mpt" in str(model.__class__).lower():
        layers = model.transformer.blocks
    elif "falcon" in str(model.__class__).lower():
        layers = model.transformer.h
    else:
        raise NotImplementedError(type(model))
    return layers
@@ -102,7 +109,6 @@ def run_awq(
        inps = layer(inps, **layer_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
@@ -112,7 +118,8 @@ def run_awq(
            w_bit=w_bit, q_config=q_config,
            input_feat=input_feat,
        )
-        apply_scale(layer, scales_list, input_feat_dict=input_feat)
+        # apply_scale(layer, scales_list, input_feat_dict=input_feat)
+        apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
        # append prefix to make names global
        awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
@@ -124,6 +131,7 @@ def run_awq(
        # append prefix to make names global
        awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
        # Haotian: check activation replacement
        del input_feat
        gc.collect()
        torch.cuda.empty_cache()
...
@@ -4,6 +4,16 @@ import torch.nn as nn
import f16s4_gemm  # with CUDA kernels
class ScaledActivation(nn.Module):
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
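A hedged usage sketch of the class defined above (values illustrative): ScaledActivation divides the activation output by its scales, while the consumer linear absorbs the inverse factor, mirroring scale_gelu_fc in auto_scale.py, so the composition is numerically unchanged.

import torch
import torch.nn as nn

hidden = 16
act = nn.GELU()
fc = nn.Linear(hidden, 8, bias=False)
scales = torch.rand(hidden) + 0.5

scaled_act = ScaledActivation(act, scales)   # divides activations by `scales`
fc.weight.data.mul_(scales.view(1, -1))      # same compensation as scale_gelu_fc

x = torch.randn(2, 3, hidden)
ref = fc(act(x))
out = fc(scaled_act(x))
print(torch.allclose(ref, out, atol=1e-5))   # the rescaled pair matches the original composition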
class WQLinear(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()
...
@@ -2,11 +2,48 @@ import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
    param = next(module.parameters())
    dtype = param.dtype
    device = param.device
    if isinstance(module, BloomBlock):
        if isinstance(module.mlp.gelu_impl, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.gelu_impl,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.gelu_impl", act)
    elif 'mptblock' in str(module.__class__.__name__).lower():
        if isinstance(module.ffn.act, ScaledActivation):
            return
        c = module.ffn.up_proj.out_features
        act = ScaledActivation(
            module.ffn.act,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "ffn.act", act)
    elif 'falcon' in str(module.__class__).lower():
        if isinstance(module.mlp.act, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.act,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
@@ -77,7 +114,8 @@ def real_quantize_model_weight(
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
+        scale_activations(layer)
        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
@@ -88,18 +126,6 @@ def real_quantize_model_weight(
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
+            set_op_by_name(layer, name, q_linear)
-            levels = name.split('.')
-            if len(levels) > 1:
-                mod_ = layer
-                for l_idx in range(len(levels)-1):
-                    if levels[l_idx].isdigit():
-                        mod_ = mod_[int(levels[l_idx])]
-                    else:
-                        mod_ = getattr(mod_, levels[l_idx])
-                setattr(mod_, levels[-1], q_linear)
-            else:
-                setattr(layer, name, q_linear)
        torch.cuda.empty_cache()
        gc.collect()
\ No newline at end of file
@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
            return 2048
        elif 'llama' in self.model_name:
            return 2048  # TODO: did not check this
        elif 'mpt' in self.model_name:
            return 2048
        elif 'falcon' in self.model_name:
            return 2048
        else:
            print(self.model.config)
            raise NotImplementedError
...
@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
    raise ValueError(f"Cannot find op {op_name} in module {module}")
def set_op_by_name(layer, name, new_module):
    levels = name.split('.')
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels)-1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)
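Usage sketch for the helper above (module layout illustrative): it replaces a nested submodule by its dotted name, including numeric indices into containers such as nn.Sequential or nn.ModuleList.

import torch.nn as nn

block = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.GELU(), nn.Linear(8, 8)))
set_op_by_name(block, "1.1", nn.Identity())   # swaps block[1][1]
print(block[1][1])                            # Identity()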
def get_op_name(module, op):
    # get the name of the op relative to the module
    for name, m in module.named_modules():
...