This is a Tilelang Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`.
## Make Checkpoints for vLLM
We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension.
```bash
# move to the integration directory
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```
The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization.
**Note:** To reproduce the results of BitBLAS, Please checkout the `benchmark_inference_latency.py`. To reproduce the results of the original model, Please checkout the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
We measured the GPU memory footprint through the `nvidia-smi` command. Please checkout `nvidia_measure_memory.sh` to get the real-time GPU memory usage. And then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage.
The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors.
## Citations
```bibtex
@article{ma2024era,
title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits},
author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu},
Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
- Tiling parameters:
- block_row_warps, block_col_warps: number of warps per block in row/col.
- warp_row_tiles, warp_col_tiles: tiles per warp.
- chunk: K-sized chunk per block (block_K).
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
- Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
- Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
Parameters:
M, N, K (int): Global matrix dimensions.
in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
block_row_warps (int): Warps in block row dimension.
block_col_warps (int): Warps in block column dimension.
warp_row_tiles (int): Tiles per warp in row dimension.
warp_col_tiles (int): Tiles per warp in column dimension.
chunk (int): K-length per block (block_K).
Returns:
T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
"""
assertin_dtypein[
"float16",
"int8",
],"Currently only float16 and int8 are supported"
assertout_dtypein[
"float16",
"float32",
"int32",
],"Currently only float16, float32 and int32 are supported"
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`.
## Latest News
- 08/09/2024 ✨: We provide a more efficient implementation for bitnet with vLLM, which should use special model checkpoints, to make the ckpt and study how to deploy, please checkout [Make Checkpoints for vLLM](#make-checkpoints-for-vllm).
## Make Checkpoints for vLLM
We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension.
```bash
# move to the integration directory
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```
The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization.
# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
```
Finnaly, you can use the ckpt in vLLM with:
```bash
cd vllm_workspace
# inference with the ckpt with fp16 uncompressed metadata
python3 inference_with_native_format.py
# inference with the ckpt with BitBLAS compressed metadata
python3 inference_with_bitblas_format.py
```
## BitBLAS Results
### Performance
**Note:** To reproduce the results of BitBLAS, Please checkout the `benchmark_inference_latency.py`. To reproduce the results of the original model, Please checkout the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
We measured the GPU memory footprint through the `nvidia-smi` command. Please checkout `nvidia_measure_memory.sh` to get the real-time GPU memory usage. And then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage.
The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors.
## Citations
```bibtex
@article{ma2024era,
title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits},
author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu},