# Benchmark a single 4-bit quantized linear layer: AutoAWQ (GEMV/GEMM kernels)
# vs. AutoGPTQ's exllamav2 kernel, across query lengths. Results are printed as CSV.
import torch

try:
    from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"AutoAWQ package (https://github.com/casper-hansen/AutoAWQ) is required to run this benchmark. {e}"
    )

import numpy as np
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

group_size = 128
bits = 4

# Yi 34B down_proj shape (intermediate_size -> hidden_size).
k = 20480
n = 7168

device = torch.device("cuda:0")

linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits)
linear_gptq = linear_class(
    bits=bits,
    group_size=group_size,
    infeatures=k,
    outfeatures=n,
    bias=False,
)
assert isinstance(linear_gptq, QuantLinear)

linear_gptq = linear_gptq.eval()
linear_gptq = linear_gptq.to(device)

# Allocate the scratch buffers the exllamav2 kernel needs before the first forward.
linear_gptq = autogptq_post_init(linear_gptq, use_act_order=False)

num_runs = 60

lines = []
seqlens = [
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    8,
    12,
    16,
    24,
    32,
    48,
    64,
    80,
    120,
    250,
    512,
    1024,
    2048,
    4000,
    8000,
]

print(f"in_features={k}, out_features={n}")

for query_length in seqlens:
    # batch_size, query_length, hidden_size
    inp = torch.rand(1, query_length, k, dtype=torch.float16).to(device)

    torch.cuda.empty_cache()

    # Warmup Exllama v2.
    with torch.no_grad():
        res = linear_gptq(inp)

    latencies = []
    torch.cuda.synchronize()
    for _ in range(num_runs):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()

        start_event.record()
        res = linear_gptq(inp)
        end_event.record()
        torch.cuda.synchronize()

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    # print("-------")
    # print(f"Latency GPTQ Exllama v2 (query_length={query_length}): {np.mean(latencies):.3f} ms, p10={np.percentile(latencies, 10):.3f}, p90={np.percentile(latencies, 90):.3f}")

    exllamav2_mean_latency = np.mean(latencies)
    exllamav2_p10 = np.percentile(latencies, 10)
    exllamav2_p90 = np.percentile(latencies, 90)

    torch.cuda.empty_cache()

    total_seqlen = inp.shape[:-1].numel()
    # AutoAWQ's GEMV kernel targets short sequences (decoding);
    # GEMM targets longer ones (prefill).
    if total_seqlen <= 8:
        awq_kernel = "GEMV"
        linear_awq = WQLinear_GEMV(
            w_bit=bits,
            group_size=group_size,
            in_features=k,
            out_features=n,
            bias=False,
            dev=device,
        )
    else:
        awq_kernel = "GEMM"
        linear_awq = WQLinear_GEMM(
            w_bit=bits,
            group_size=group_size,
            in_features=k,
            out_features=n,
            bias=False,
            dev=device,
        )

    # Warmup AWQ.
    with torch.no_grad():
        res = linear_awq(inp)

    latencies = []
    torch.cuda.synchronize()
    for _ in range(num_runs):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()

        start_event.record()
        res = linear_awq(inp)
        end_event.record()
        torch.cuda.synchronize()

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    awq_mean_latency = np.mean(latencies)
    awq_p10 = np.percentile(latencies, 10)
    awq_p90 = np.percentile(latencies, 90)

    # > 1 means exllamav2 is faster than AWQ for this shape.
    exllama_speedup = awq_mean_latency / exllamav2_mean_latency

    # print(f"Latency AWQ (query_length={query_length}, kernel={awq_kernel}): {np.mean(latencies):.3f} ms, p10={np.percentile(latencies, 10):.3f}, p90={np.percentile(latencies, 90):.3f}")

    line = "{},{},{},{},{},{},{},{},{},{},{}".format(
        bits,
        group_size,
        total_seqlen,
        awq_kernel,
        f"{awq_mean_latency:.3f}",
        f"{exllamav2_mean_latency:.3f}",
        f"{awq_p10:.3f}",
        f"{awq_p90:.3f}",
        f"{exllamav2_p10:.3f}",
        f"{exllamav2_p90:.3f}",
        f"{exllama_speedup:.3f}",
    )
    lines.append(line)

header = "bits, group_size, total_seqlen, awq_kernel, awq_mean_latency (ms), exllamav2_mean_latency (ms), awq_p10, awq_p90, exllamav2_p10, exllamav2_p90, exllama_speedup"

print(header)
for line in lines:
    print(line)
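
# Optionally persist the results instead of only printing them. This is a
# minimal sketch, not part of the original benchmark; the
# "benchmark_awq_exllamav2.csv" filename is an assumption.
with open("benchmark_awq_exllamav2.csv", "w") as f:
    f.write(header + "\n")
    for line in lines:
        f.write(line + "\n")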