import os
import unittest

import torch
import torch.utils.benchmark as benchmark
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM

MODEL_ID = "TheBloke/Llama-7B-GPTQ"
# Dataset / training constants (not exercised by the kernel test below).
DATASET_ID = "timdettmers/openassistant-guanaco"
LEARNING_RATE = 3e-5
MAX_SEQ_LEN = 10
BATCH_SIZE = 5
NUM_TRAIN_STEPS = 10

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def benchmark_forward(
    fn,
    *inputs,
    repeats="auto",
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time the forward pass of `fn`, optionally under autocast (AMP)."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    # blocked_autorange() picks the number of runs automatically so the
    # measurement is long enough to be stable; timeit(n) runs exactly n times.
    if repeats == "auto":
        m = t.blocked_autorange()
    else:
        m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def get_model_and_tokenizer(
    model_id=MODEL_ID,
    inject_fused_attention=False,
    inject_fused_mlp=False,
    **model_kwargs,
):
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Disable both exllama kernels so the Triton path is the one under test.
    model = AutoGPTQForCausalLM.from_quantized(
        model_id,
        trainable=True,
        inject_fused_attention=inject_fused_attention,
        inject_fused_mlp=inject_fused_mlp,
        disable_exllamav2=True,
        disable_exllama=True,
        **model_kwargs,
    )
    model.warmup_triton()
    return model, tokenizer


class TestTriton(unittest.TestCase):
    def test_triton_qlinear(self):
        # Reference model uses the original Triton QuantLinear kernel,
        # the test model uses the Triton-v2 kernel.
        ref_model, _ = get_model_and_tokenizer(
            model_id=MODEL_ID,
            use_triton=True,
            inject_fused_attention=False,
            inject_fused_mlp=False,
        )
        test_model, _ = get_model_and_tokenizer(
            model_id=MODEL_ID,
            use_tritonv2=True,
            inject_fused_attention=False,
            inject_fused_mlp=False,
        )
        hidden_size = ref_model.model.model.embed_tokens.weight.shape[1]
        test_data = torch.randn((1, 2048, hidden_size), dtype=torch.float16).cuda()

        # Compare a single quantized linear layer from each model.
        qlinear_ref = ref_model.model.model.layers[0].self_attn.q_proj
        qlinear_test = test_model.model.model.layers[0].self_attn.q_proj

        test_out = qlinear_test(test_data)
        ref_out = qlinear_ref(test_data)
        # The two kernels must agree numerically...
        self.assertTrue(torch.allclose(test_out, ref_out))

        # ...and Triton-v2 should be faster on average than the original kernel.
        _, measure_triton = benchmark_forward(qlinear_ref, test_data, desc="Triton", verbose=True)
        _, measure_tritonv2 = benchmark_forward(qlinear_test, test_data, desc="Triton-v2", verbose=True)
        self.assertTrue(measure_tritonv2.mean < measure_triton.mean)
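
# A minimal sketch of running this test directly, assuming a CUDA-capable GPU
# with auto-gptq, transformers, and triton installed. The filename used in the
# second command, test_triton.py, is only illustrative:
#
#   python test_triton.py -v
#   python -m unittest test_triton.py -v
#
if __name__ == "__main__":
    unittest.main()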