"""Measure GPU memory usage (and latency, via `profile`) of a BitNet 1.58-bit model."""

import argparse

import torch

from modeling_bitnet import BitnetForCausalLM

# Inference-only benchmark: gradients are never needed.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)


def profile(model, input_data):
    """Return the mean latency of one forward pass of `model` in milliseconds.

    Moves the model to CUDA and switches it to eval mode, warms up for
    roughly one second, then times enough repeats to fill about one more
    second and returns the average per-call runtime (ms).

    NOTE(review): requires a CUDA device; `input_data` is assumed to already
    be on the correct device — confirm with callers.
    """
    import time

    model = model.cuda()
    model.eval()

    def get_runtime(num_repeats=1):
        # Average wall-clock time (ms) over `num_repeats` forward passes.
        tic = time.time()
        for _ in range(num_repeats):
            _ = model(input_data)
        # CUDA kernel launches are asynchronous; wait for completion
        # before reading the clock.
        torch.cuda.synchronize()
        return (time.time() - tic) * 1000 / num_repeats

    with torch.no_grad():
        # Warm up for ~1 s so lazy initialization / kernel caching
        # does not skew the measurement.
        st = time.time()
        while time.time() - st < 1.0:
            get_runtime()
        warmup_runtime = get_runtime()
        # Size the measured run to take roughly one second in total.
        num_repeats = max(1, int(1000 / warmup_runtime))
        # get_runtime already returns the mean over `num_repeats`,
        # so no further averaging is needed.
        return get_runtime(num_repeats)


def main():
    """Load the model named by --hf_path and report GPU memory before/after
    post-processing the weights."""
    args = parser.parse_args()
    # Previously the path was hardcoded; honor the --hf_path CLI flag
    # (its default matches the old hardcoded value).
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB")
    with torch.no_grad():
        model._post_process_weights()
    print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB")


if __name__ == "__main__":
    main()