Commit 58c6b912 authored by Lei Wang, committed by LeiWang1999

[Example] Add bitnet-1.58b examples (#399)

* Update copyright notice in example_mha_bwd_wgmma_pipelined.py to reflect Tile-AI Corporation ownership.

* lint fix
parent ca730c0a
MODEL_DIR=$1
REMOTE_DIR=$2
if [ ! -d "$MODEL_DIR" ]; then
echo "Model directory does not exist!"
exit 1
fi
cd "$MODEL_DIR" || exit 1
# Remove any existing git history so the checkpoint is pushed as a fresh repository.
if [ -d ".git" ]; then
rm -rf .git
fi
git init
git checkout -b main
git lfs install
git lfs track "*.bin"
git lfs track "*.safetensors"
git add .
git commit -m "Initial commit"
git remote add origin "$REMOTE_DIR"
huggingface-cli lfs-enable-largefiles .
git fetch origin
git push -f --set-upstream origin main
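For example, assuming the script is saved as upload_model.sh (the file name is illustrative), it can be run as `bash upload_model.sh ./models/ckpt_bitnet_b1_58-3B https://huggingface.co/<user>/<repo>`; git-lfs must be installed, and the remote repository is expected to already exist.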
nvidia-smi --query-gpu=memory.used --format=csv -lms 500
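# Logs the GPU's used memory every 500 ms (-lms sets the loop interval in milliseconds);
# useful for tracking peak memory while the examples run.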
lm_eval==0.3.0
flash_attn
transformers==4.40
# pylint: disable=missing-docstring, invalid-name
"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS."""
import torch
from torch import nn
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas import auto_detect_nvidia_target
from logging import getLogger
logger = getLogger(__name__)
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = get_database_path()
def weight_quant(weight, num_bits=1):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)
def activation_quant(x, num_bits=8):
dtype = x.dtype
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
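As a quick numeric illustration of the two fake-quantization helpers above, here is a standalone sketch (not part of this file; the tensor values are arbitrary and the utils_quant module name is an assumption):

import torch
from utils_quant import weight_quant, activation_quant  # assumed module name

w = torch.tensor([[0.4, -0.7, 0.05], [1.2, -0.1, 0.3]])
x = torch.tensor([[0.25, -3.0, 0.5]])
# weight_quant snaps every entry onto {-m, 0, +m}, where m = mean(|w|).
print(weight_quant(w))
# activation_quant rounds each row onto a signed 8-bit grid and rescales back to float.
print(activation_quant(x))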
class BitLinearBitBLAS(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bits=1,
input_bits=8,
**kwargs,
):
super().__init__()
"""
RMSNorm is placed outside BitLinear
"""
self.in_features = in_features
self.out_features = out_features
self.weight_bits = weight_bits
self.input_bits = input_bits
matmul_config = MatmulConfig(
N=self.out_features, # N dimension
K=self.in_features, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int2", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how zeros are calculated
)
ENABLE_TUNING = True
self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)
self.format = "bitnet"
self.Qp = 2**(self.input_bits - 1) - 1
def _get_or_create_bitblas_operator(self, config, enable_tuning):
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
logger.info(f"Loaded {global_operator_cache.size()} operators from database.")
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
# Create the operator with tuning disabled first, since it may still be loadable from the database.
bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
print("BitBLAS Tuning done, appended operator to global_operator_cache.")
else:
print("BitBLAS Operator created.")
else:
print("BitBLAS Operator found in global_operator_cache.")
return bitblas_matmul
def replace_weight_param_with_qweight(self):
if hasattr(self, "weight"):
del self.weight
quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape())
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@classmethod
def from_bit_linear(cls, bitlinear, weight_group=1):
bitblas_linear = cls(
bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
bitblas_linear.register_buffer("qweight", qweight)
bitblas_linear.register_buffer("sw", sw)
if bitlinear.bias is not None:
bitblas_linear.register_buffer("bias", bitlinear.bias)
else:
bitblas_linear.bias = None
return bitblas_linear
def create_bitblas_weights(self, weight, weight_group=1):
if weight_group:
hidden_size = weight.size(0)
group_size = hidden_size // weight_group
sw_list = []
qweight_list = []
for i in range(weight_group):
start_idx = i * group_size
end_idx = (i + 1) * group_size
sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5)
sw_list.append(sw.repeat(group_size))
qweight = self.weight_quant(weight[start_idx:end_idx]).detach()
qweight_list.append(qweight)
sw = torch.cat(sw_list, dim=0)
qweight = torch.cat(qweight_list, dim=0)
else:
sw = 1 / weight.abs().mean().clamp(min=1e-5)
qweight = self.weight_quant(weight).detach()
qweight = self.bitblas_matmul.transform_weight(qweight)
qweight = nn.Parameter(qweight, requires_grad=False)
return sw, qweight
def post_process_weights(self):
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
self.sw = sw
quant_weight = self.weight_quant(self.weight).detach()
quant_weight = self.bitblas_matmul.transform_weight(quant_weight)
# remove self.weight and replace it with quant_weight
if hasattr(self, "weight"):
del self.weight
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@staticmethod
def weight_quant(weight):
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1)
return result.type(torch.int8)
@torch.compile
def activation_quant(self, x, num_bits=8):
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8), s
@torch.compile
def post_quant_process(self, input, si, sw):
out = input / si
out = out / sw
out = out.half()
return out
# for the correctness evaluation.
def native_forward(self, input):
quant_input = (input + (activation_quant(input, self.input_bits) - input).detach())
quant_weight = (
self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach())
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward_fp32_simulated(self, input):
quant_input, si = self.activation_quant(input, self.input_bits)
quant_input = quant_input.detach()
quant_weight = self.weight_quant(self.weight).detach()
fp32_simulated_input = quant_input.float()
fp32_simulated_weight = quant_weight.float()
fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
# dividing by (si * sw) in a single step can overflow to inf in some cases, so rescale sequentially
out = fp32_simulated_out / si
out = out / sw
out = out.half()
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward(self, input):
# return self.forward_fp32_simulated(input)
quant_input, si = self.activation_quant(input, self.input_bits)
fp32_out = self.bitblas_matmul(quant_input, self.qweight)
sw = self.sw
# dividing by (si * sw) in a single step can overflow to inf in some cases, so rescale sequentially
out = self.post_quant_process(fp32_out, si, sw)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):
def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
super(BitLinear, self).__init__(*kargs, **kwargs)
"""
RMSNorm is placed outside BitLinear
"""
self.weight_bits = weight_bits
self.input_bits = input_bits
def forward(self, input):
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) -
self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
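A minimal end-to-end sketch of converting the naive layer into the BitBLAS-backed one and comparing the two outputs (standalone, not part of this file; it assumes a CUDA device and a working BitBLAS install, the first call may trigger kernel tuning, and only loose agreement is expected because the two paths round and accumulate differently):

import torch
from utils_quant import BitLinear, BitLinearBitBLAS  # assumed module name

torch.manual_seed(0)
layer = BitLinear(1024, 1024, bias=False).half().cuda()
x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")

ref = layer(x)  # straight-through fake-quantized reference
bitblas_layer = BitLinearBitBLAS.from_bit_linear(layer).cuda()
out = bitblas_layer(x)  # int8 activation x int2 weight BitBLAS kernel, then rescale

sim = torch.nn.functional.cosine_similarity(ref.float().flatten(), out.float().flatten(), dim=0)
print(f"cosine similarity between naive and BitBLAS outputs: {sim.item():.4f}")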
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitblas",
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024)
print("bitnet inference:")
print(bitnet_outputs[0][0])
print(bitnet_outputs[0][1])
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitnet_bitblas",
gpu_memory_utilization=0.5,
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128)
print("bitnet inference output:")
print(bitnet_outputs[0][0])
print(bitnet_outputs[0][1])
from typing import Dict, List, Tuple
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
name_0: str, name_1: str):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
# Loop through responses to each prompt.
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If the generated tokens don't match, apply a looser criterion:
if output_id_0 != output_id_1:
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
# Break out since sequences will now diverge.
break
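These helpers expect the per-prompt outputs of two runs. A hedged usage sketch follows (the generate_greedy_logprobs method, checkpoint paths, and logprob count are assumptions patterned on the inference scripts above):

from conftest import VllmRunner  # same test harness as in the scripts above

prompts = ["Hi, tell me about microsoft?"]
with VllmRunner("../models/ckpt_bitnet_b1_58-3B", dtype="half",
                quantization="bitnet_bitblas") as ref_model:
    ref_outputs = ref_model.generate_greedy_logprobs(prompts, max_tokens=128, num_logprobs=3)
with VllmRunner("../models/ckpt_bitnet_b1_58-3B_bitblas", dtype="half",
                quantization="bitblas") as bitblas_model:
    bitblas_outputs = bitblas_model.generate_greedy_logprobs(prompts, max_tokens=128, num_logprobs=3)

# Token-by-token: mismatching tokens must still appear in the other run's top-3 logprobs.
check_logprobs_close(bitblas_outputs, ref_outputs, name_0="bitblas", name_1="bitnet_bitblas")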
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
...