# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd.py
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16"):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax we need to reset scores_max to 0 when it is -inf (i.e. every score in
# the row is masked), otherwise the exp2 below would evaluate -inf - (-inf) and produce NaN.
# This step is called check_inf in the FlashAttention-3 code, and it is only needed when a
# row can be fully masked.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This allows the compiler to use the ffma instruction instead of a separate fadd and fmul.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
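# Attention sink: each head's sink logit contributes exp(sink - row_max) to the softmax
# denominator and nothing to the numerator. scores_max holds the row max of the raw QK
# scores, so scores_max[i] * scale is that max in the scaled base-2 domain; the sink is
# only multiplied by log2(e) because, as in the reference, it is not scaled by sm_scale.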
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks
def main(batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅")
latency = do_bench(
lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16"):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax we need to reset scores_max to 0 when it is -inf (i.e. every score in
# the row is masked), otherwise the exp2 below would evaluate -inf - (-inf) and produce NaN.
# This step is called check_inf in the FlashAttention-3 code, and it is only needed when a
# row can be fully masked.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This allows the compiler to use the ffma instruction instead of a separate fadd and fmul.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
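# Manually scheduled software pipeline for the Hopper/WGMMA path: `order`, `stage` and
# `group` split the loop body (roughly: K copy, mask init + QK GEMM, the softmax statements,
# rescale, V copy, PV GEMM) into statement groups and give each group an issue order and
# pipeline stage, so copies for a later iteration overlap with compute of the current one.
# The schedule values mirror the upstream example this file was adapted from.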
for k in T.Pipelined(
start[0],
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅")
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
import tilelang.testing
import example_mha_sink_fwd_bhsd
import example_mha_sink_fwd_bhsd_wgmma_pipelined
import example_gqa_sink_fwd_bhsd_wgmma_pipelined
import example_mha_sink_bwd_bhsd
import example_gqa_sink_bwd_bhsd
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_full_attn():
example_mha_sink_fwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_sliding_window():
example_mha_sink_fwd_bhsd.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
@tilelang.testing.requires_cuda
def test_example_mha_sink_bwd_bhsd():
example_mha_sink_bwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_mha_sink_bwd_bhsd_sliding_window():
example_mha_sink_bwd_bhsd.main(window_size=128)
@tilelang.testing.requires_cuda
def test_example_gqa_sink_bwd_bhsd():
example_gqa_sink_bwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_gqa_sink_bwd_bhsd_sliding_window():
example_gqa_sink_bwd_bhsd.main(window_size=128)
if __name__ == "__main__":
tilelang.testing.main()
models/
---
license: mit
---
This is a Tilelang implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with an INT8xINT2 kernel. The model's correctness and performance can be evaluated with `eval_correctness.py` and `benchmark_inference_latency.py`.
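For intuition, the INT2 side stores four 2-bit weight codes per byte, which is where the 4x weight-memory saving over int8 comes from. Below is a minimal NumPy sketch of that packing layout. The helpers `pack_int2`/`unpack_int2` are illustrative only and not part of the repo; the repo's own helpers (`general_compress`, `interleave_weight`) additionally interleave the codes for the CUDA decode path.

```python
import numpy as np

def pack_int2(codes: np.ndarray) -> np.ndarray:
    """Pack four 2-bit codes (values 0..3) into each byte, lowest code in the lowest bits."""
    assert codes.shape[-1] % 4 == 0
    codes = codes.astype(np.uint8) & 0x3
    packed = np.zeros((*codes.shape[:-1], codes.shape[-1] // 4), dtype=np.uint8)
    for k in range(4):
        packed |= codes[..., k::4] << np.uint8(2 * k)
    return packed.view(np.int8)

def unpack_int2(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_int2: recover the 2-bit codes in their original order."""
    p = packed.view(np.uint8)
    parts = [(p >> np.uint8(2 * k)) & np.uint8(0x3) for k in range(4)]
    return np.stack(parts, axis=-1).reshape(*p.shape[:-1], -1)

w = np.random.randint(0, 3, size=(8, 16))  # e.g. ternary weights {-1, 0, 1} shifted to {0, 1, 2}
assert (unpack_int2(pack_int2(w)) == w).all()
```

In the actual kernel the int8 activations stay uncompressed; the packed weights are decoded back to int8 on the fly (see `decode_i2s_to_i8s` / `decode_i2u_to_i8s` below) and accumulated with `dp4a`.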
## Make Checkpoints for vLLM
We provide two scripts to make the checkpoints for vLLM. The first, `generate_bitnet_model_native_format.sh`, creates a checkpoint with fp16 uncompressed metadata. The main difference from the original checkpoint is the added `quant_config.json`, which allows vLLM to load the model and execute it with the quantization extension.
```bash
# move to the integration directory
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output checkpoint will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```
The second script, `generate_bitnet_model_bitblas_format.sh`, creates a checkpoint with BitBLAS compressed metadata. This avoids the online dequantization stage during vLLM profiling and leads to more efficient memory utilization.
```bash
./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
# the output checkpoint will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
```
Finally, you can use the checkpoint in vLLM with:
```bash
cd vllm_workspace
# inference with the ckpt with fp16 uncompressed metadata
python3 inference_with_native_format.py
# inference with the ckpt with BitBLAS compressed metadata
python3 inference_with_bitblas_format.py
```
**Benchmark results of vLLM**
| Model | Framework | BS16IN32OUT128 | BS1IN512OUT1024 | BS32IN32OUT128 |
|------------------------|--------------------------|----------------|-----------------|----------------|
| bitnet-3b-1.58bits | pytorch | 106.83 | 49.34 | 209.03 |
| bitnet-3b-1.58bits | pytorch-tilelang | 240.33 | 103.09 | 493.31 |
| bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 |
| bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 |
## BitBLAS Results
### Performance
**Note:** To reproduce the BitBLAS results, please check out `benchmark_inference_latency.py`. To reproduce the results of the original model, please check out the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
| Model | Device | batchsize | in_seq | arch | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 |
| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 |
| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 |
### On-the-Fly GPU Memory Footprint
We measured the GPU memory footprint with the `nvidia-smi` command. Please check out `nvidia_measure_memory.sh` to record the real-time GPU memory usage, then start the `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage (a minimal polling sketch follows the table below).
| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** |
|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB |
| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB |
| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB |
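A minimal sketch of such a polling loop is shown below. It is an illustration only, not the repo's `nvidia_measure_memory.sh`, and the helper name `poll_peak_gpu_memory_mb` is hypothetical; it assumes `nvidia-smi` is available on the PATH.

```python
import subprocess
import time

def poll_peak_gpu_memory_mb(duration_s: float = 60.0, interval_s: float = 0.5) -> int:
    """Poll `nvidia-smi` and return the peak GPU memory usage (MiB) seen during the window."""
    peak = 0
    deadline = time.time() + duration_s
    while time.time() < deadline:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"])
        peak = max(peak, max(int(x) for x in out.split()))
        time.sleep(interval_s)
    return peak

# Run this in one process while `benchmark_model_10k_loops.py` runs in another.
print(f"peak memory: {poll_peak_gpu_memory_mb()} MiB")
```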
## PPL and Zero-shot Accuracy
The numbers are reported from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B); please check out `eval_ppl.py` to reproduce them.
PPL and zero-shot accuracy:
| Models | PPL | ARCe | ARCc | HS | BQ | OQ | PQ | WGe | Avg |
|--------|-----|------|------|----|----|----|----|-----|-----|
| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 |
| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 |
| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 |
| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 |
| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 |
| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 |
| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 |
| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 |
| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 |
The differences between the reported numbers and the reproduced results most likely come from variance in training data processing, seeds, or other random factors.
## Citations
```bibtex
@article{ma2024era,
title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits},
author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu},
journal={arXiv preprint arXiv:2402.17764},
year={2024}
}
```
python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log
python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log
python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 | tee b32_i32_o128.log
python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b16_i32_o128_bitblas.log
python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 --bitblas | tee b1_i512_o64_bitblas.log
python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b32_i32_o128_bitblas.log
import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from transformers import GenerationConfig
import time
import argparse
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")
def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch
input_ids = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
# Generate cos and sin values (commented out as not used in generation)
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# position_embeddings = model.embed_positions(position_ids)
# cos = position_embeddings[:, :, 0::2].cos()
# sin = position_embeddings[:, :, 1::2].sin()
generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)
start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
# output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin)
end_time = time.time()
# Decode the output ids to text
generated_texts = [
tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids
]
generation_time = end_time - start_time
num_tokens = sum(len(output_id) for output_id in output_ids)
tokens_per_second = num_tokens / generation_time
print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")
return generated_texts
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=16, type=int)
parser.add_argument('--in_seq_len', default=32, type=int)
parser.add_argument('--out_seq_len', default=128, type=int)
parser.add_argument('--bitblas', action='store_true')
args = parser.parse_args()
bs = args.bs
in_seq_len = args.in_seq_len
out_seq_len = args.out_seq_len
is_bitblas = args.bitblas
model = BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
if is_bitblas:
with torch.no_grad():
model.quantize()
tokenizer = BitnetTokenizer.from_pretrained(model_path)
prompt = ""
for _ in range(in_seq_len):
prompt += "Hello "
prompts = []
for _ in range(bs):
prompts.append(prompt)
max_length = out_seq_len + in_seq_len
print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))
if __name__ == '__main__':
main()
import argparse
import torch
from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
with torch.no_grad():
model.quantize()
model = torch.compile(model)
benchmark_sets = [(1024, 1), (1, 2048)]
for batch_size, seq_len in benchmark_sets:
input_id = torch.ones(batch_size, seq_len).long().cuda()
latency = profile(model, input_id)
print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
if __name__ == '__main__':
main()
import argparse
import torch
from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--seq_len", default=1, type=int)
args = parser.parse_args()
seq_len = args.seq_len
batch_size = args.batch_size
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
def main():
model = BitnetForCausalLM.from_pretrained(
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
with torch.no_grad():
model._post_process_weights()
torch.cuda.empty_cache()
input_id = torch.ones(batch_size, seq_len).long().cuda()
for _ in range(10000):
_ = model(input_id)
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class BitnetConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`BitnetModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`BitnetModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Bitnet 1 supports up to 2048 tokens,
Bitnet 2 up to 4096, CodeBitnet up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import BitnetModel, BitnetConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = BitnetConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = BitnetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
weight_bits=1,
input_bits=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.weight_bits = weight_bits
self.input_bits = input_bits
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from transformers import GenerationConfig
import time
import transformers
print(f"transformers version is {transformers.__version__}")
# version must be lower than or equal to 4.40
assert transformers.__version__ <= "4.40.0"
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")
def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Generate cos and sin values
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)
start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
end_time = time.time()
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time
print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")
return generated_text
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
def main():
model = BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half()
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
input_id = tokenizer("Hello")['input_ids']
input_id = torch.tensor(input_id).unsqueeze(0).cuda()
print("original model generated text:")
print(generate_text(model, tokenizer, "Hello", max_length=100))
model.quantize()
print("quantized model generated text:")
print(generate_text(model, tokenizer, "Hello", max_length=100))
if __name__ == '__main__':
main()
import argparse
import torch
from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
with torch.no_grad():
model._post_process_weights()
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
if __name__ == '__main__':
main()
# pylint: disable=missing-docstring, invalid-name
"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py."""
import math
import argparse
import torch
import random
from eval_utils import get_test_dataset
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from tqdm import tqdm
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument('--seqlen', default=2048, type=int)
def calculate_loss(model, input, loss_fct):
output = model(input, use_cache=False, output_hidden_states=False, output_attentions=False)[0]
shift_logits = output[:, :-1, :].contiguous()
shift_labels = input[:, 1:]
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return loss
def main(args):
datasets = ['c4', 'wikitext2']
model = BitnetForCausalLM.from_pretrained(
args.hf_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
with torch.no_grad():
model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
ppl = []
for dataset in datasets:
testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
acc_loss, count = 0.0, 0
progress = tqdm(range(len(testdata)))
for ii in progress:
input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
loss = calculate_loss(model, input, loss_fct)
count += (input.size(-1) - 1)
acc_loss += loss.item()
progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}")
avg_loss = acc_loss / count / math.log(2)
ppl.append(2**avg_loss)
print("{} PPL: {}".format(dataset, ppl[-1]))
print(ppl)
print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == '__main__':
torch.set_grad_enabled(False)
args = parser.parse_args()
random.seed(args.seed)
torch.random.manual_seed(args.seed)
main(args)
# ruff: noqa
import torch
import numpy as np
import torch.nn.functional as F
from lm_eval.base import BaseLM
from datasets import load_dataset
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)
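# get_test_dataset greedily packs whole tokenized documents (each terminated with EOS) into
# sequences of at most `seqlen` tokens, starting every sequence with BOS; documents longer
# than `seqlen` are skipped.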
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
if dataset_name == "wikitext2":
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
testdata = "".join(testdata['text']).split('\n')
elif dataset_name == "c4":
testdata = load_dataset(
'allenai/c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')['text']
else:
raise NotImplementedError
testdata = [item for item in testdata if item != ""]
tokenized_text = [
tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id]
for item in testdata
]
data, doc = [], [tokenizer.bos_token_id]
for sen in tokenized_text:
if len(sen) > seqlen:
continue
if len(doc) + len(sen) > seqlen:
data.append(doc)
doc = [tokenizer.bos_token_id]
doc.extend(sen)
if len(doc) > 1 and len(doc) <= seqlen:
data.append(doc)
return data
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__()
assert isinstance(batch_size, int)
self.model_name = model_name
self.model = model
self.model.eval()
self.tokenizer = tokenizer
self.vocab_size = self.tokenizer.vocab_size
self._batch_size = batch_size
self._max_length = max_length
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length != -1:
return self._max_length
if hasattr(self.model.config, "n_ctx"):
return self.model.config.n_ctx
elif hasattr(self.model.config, "max_position_embeddings"):
return self.model.config.max_position_embeddings
elif hasattr(self.model.config, "n_positions"):
return self.model.config.n_positions
elif "bloom" in self.model_name:
return 2048
elif "llama" in self.model_name:
return 2048 # TODO: did not check this
elif "mpt" in self.model_name:
return 2048
elif "falcon" in self.model_name:
return 2048
else:
print(self.model.config)
raise NotImplementedError
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self._batch_size
@property
def device(self):
return "cuda"
def tok_encode(self, string: str, add_special_tokens=True):
return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in requests:
context, continuation = context.strip(), continuation.strip()
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
else:
context_enc = self.tok_encode(context, add_special_tokens=True)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
out = self.model(inps)[0]
return out
def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
import numpy as np
from tilelang.transform import simplify_prim_func
torch.manual_seed(42)
decode_i2s_to_i8s = """template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert 8 int2b_t to 8 int8b_t -> 2 int32
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// Extract the 2-bit fields and remap them to signed int8 values.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // select the low 2 bits of each byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000; // no additive magic needed for the int8 path
static constexpr uint MEDIAN_NUM = 0x02020202; // subtracted below to map {0..3} to {-2..1}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsub4(i8s[i], MEDIAN_NUM);
}
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert 8 int2b_t to 8 int8b_t -> 2 int32
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// Extract the 2-bit fields as unsigned int8 values (no median subtraction).
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // select the low 2 bits of each byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000; // no additive magic needed for the int8 path
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
}
}
"""
@simplify_prim_func
def bitnet_158_int8xint2_decode(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True,
n_partition=4,
reduce_thread=32,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
storage_nbit = 8
num_bits = 2
A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
C_shape = (M, N)
num_elems_per_byte = 4
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
storage_dtype = "int8"
block_K = reduce_thread * micro_size_k
use_dp4a = True
dp4a_size = 4
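# Thread mapping: each block computes an n_partition-wide slice of one output row. Within the
# block, threadIdx.y (ni) selects the output column and threadIdx.x (kr) owns a reduce_thread-way
# split of the K dimension; each thread decodes its packed 2-bit weights to int8, accumulates a
# partial dot product with dp4a, and the partials are combined with a cross-thread allreduce.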
@T.prim_func
def kernel(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.import_source(decode_i2s_to_i8s)
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
]
T.call_extern(
"handle",
"decode_i2u_to_i8s",
T.address_of(B_quant_local[0]),
T.address_of(B_dequantize_local[0]),
)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_dequantize_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return kernel
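# Note: before calling the kernel above, the weight operand B must be packed with
# general_compress(..., source_bits=2) and reordered with interleave_weight(..., 2, target_dtype="int8")
# so that it matches the layout decode_i2u_to_i8s expects (see the correctness check below).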
def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
elems_per_byte = 8 // source_bits
if lowprecision_weight.dtype == np.float16:
lowprecision_weight = lowprecision_weight.astype(dtype=np.int8)
int8_weight = np.zeros(
(
*lowprecision_weight.shape[:-1],
lowprecision_weight.shape[-1] // elems_per_byte,
),
dtype=np.int8,
)
for j in range(lowprecision_weight.shape[-1] // elems_per_byte):
for k in range(elems_per_byte):
int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k)
return int8_weight.view(storage_dtype)
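# Worked example (illustrative): with source_bits=2, elems_per_byte is 4, so the four 2-bit values
# [1, 0, 1, 0] are packed little-end-first into one byte: (1 << 0) | (0 << 2) | (1 << 4) | (0 << 6) = 0x11.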
# interleave weight numpy implementation
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
assert target_dtype in ["float16", "int8"]
# reinterpret the data type of qweight to int32
qweight = qweight.view(np.int32)
new_qweight = np.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
for i in range(num_groups):
for j in range(elems_per_group):
offset = i * elems_per_group + j
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
# special handling for 1b interleave
n16_weight = new_qweight & np.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24
n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(np.int8)
elif nbits == 2 and target_dtype == "float16":
n8_weight = new_qweight & np.int32(0xFF0000FF)
n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(np.int8)
elif nbits == 1 and target_dtype == "float16":
n8_weight = new_qweight & 0xF000000F
n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24
n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4
n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12
n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20
return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_decode_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
print(src_code)
A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8)
qw = interleave_weight(qw, 2, target_dtype=in_dtype)
qw = torch.from_numpy(qw).to(device="cuda")
kernel(A, qw, C)
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tilelang.intrinsics.mma_layout import (
make_mma_swizzle_layout as make_swizzle_layout,)
import numpy as np
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
torch.manual_seed(42)
decode_i2s_to_i8s = """template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert N int2b_t values into N int8b_t values (packed as N/4 int32 words)
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// lop3 extracts the 2-bit fields; __vsub4 then rebases them from 0..3 to signed -2..1.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // select the low 2 bits of every byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000; // no magic bias is needed for the int8 path
static constexpr uint MEDIAN_NUM = 0x02020202;
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsub4(i8s[i], MEDIAN_NUM);
}
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert N int2b_t values into N int8b_t values (packed as N/4 int32 words)
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// lop3 extracts the 2-bit fields; values stay in the unsigned range 0..3.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // select the low 2 bits of every byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000; // no magic bias is needed for the int8 path
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
}
}
"""
@simplify_prim_func
def bitnet_158_int8xint2_prefill(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True,
block_row_warps=2,
block_col_warps=2,
warp_row_tiles=32,
warp_col_tiles=32,
chunk=64,
):
"""
Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
- Tiling parameters:
- block_row_warps, block_col_warps: number of warps per block in row/col.
- warp_row_tiles, warp_col_tiles: tiles per warp.
- chunk: K-sized chunk per block (block_K).
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
- Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
- Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
Parameters:
M, N, K (int): Global matrix dimensions.
in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
block_row_warps (int): Warps in block row dimension.
block_col_warps (int): Warps in block column dimension.
warp_row_tiles (int): Tiles per warp in row dimension.
warp_col_tiles (int): Tiles per warp in column dimension.
chunk (int): K-length per block (block_K).
Returns:
T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
"""
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if accum_dtype == "int32":
micro_size_k = 32
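# Note: int8 tensor-core MMA tiles use K=32 (vs. K=16 for fp16), hence micro_size_k is doubled
# when accumulating in int32.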
num_elems_per_byte = 4
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
shared_scope = "shared.dyn"
storage_dtype = "int8"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K) # activations: one int8/float16 element per entry
B_shape = (N, K // num_elems_per_byte) # weights: each int8 byte packs four 2-bit elements
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
fragment_size_a = (micro_size_x * micro_size_k) // warp_size
fragment_size_b = (micro_size_y * micro_size_k) // warp_size
fragment_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = INT4TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
"""
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
"""
with T.Kernel(
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
threads=threads,
prelude=decode_i2s_to_i8s,
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(
B_dequantize_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_frag = T.alloc_local((warp_rows * fragment_size_a), in_dtype)
B_frag = T.alloc_local((warp_cols * fragment_size_b), in_dtype)
C_frag = T.alloc_local((warp_rows * warp_cols * fragment_size_c), accum_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_frag)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]
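# Cooperative decode (descriptive note): each thread copies `local_size_compressed` packed bytes
# from B_shared into registers, expands them to `local_size` int8 values via the extern
# decode_i2u_to_i8s routine, and writes the result into B_dequantize_shared for the MMA stage.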
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = (
i * threads * local_size_compressed +
thread_bindings * local_size_compressed + v)
vi, vj = T.index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj]
T.call_extern(
"handle",
"decode_i2u_to_i8s",
T.address_of(B_local[0]),
T.address_of(B_dequantize_local[0]),
)
for v in T.vectorized(0, local_size):
index = (i * threads * local_size + thread_bindings * local_size + v)
vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_frag,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_frag,
B_dequantize_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_frag, B_frag, C_frag)
# Perform STMatrix
mma_emitter.stmatrix(
C_frag,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
elems_per_byte = 8 // source_bits
if lowprecision_weight.dtype == np.float16:
lowprecision_weight = lowprecision_weight.astype(dtype=np.int8)
int8_weight = np.zeros(
(
*lowprecision_weight.shape[:-1],
lowprecision_weight.shape[-1] // elems_per_byte,
),
dtype=np.int8,
)
for j in range(lowprecision_weight.shape[-1] // elems_per_byte):
for k in range(elems_per_byte):
int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k)
return int8_weight.view(storage_dtype)
# interleave weight numpy implementation
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
assert target_dtype in ["float16", "int8"]
# reinterpret the data type of qweight to int32
qweight = qweight.view(np.int32)
new_qweight = np.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
for i in range(num_groups):
for j in range(elems_per_group):
offset = i * elems_per_group + j
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
# special handling for 1b interleave
n16_weight = new_qweight & np.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24
n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(np.int8)
elif nbits == 2 and target_dtype == "float16":
n8_weight = new_qweight & np.int32(0xFF0000FF)
n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(np.int8)
elif nbits == 1 and target_dtype == "float16":
n8_weight = new_qweight & 0xF000000F
n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24
n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4
n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12
n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20
return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_prefill_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
print(src_code)
A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8)
qw = interleave_weight(qw, 2, target_dtype=in_dtype)
qw = torch.from_numpy(qw).to(device="cuda")
kernel(A, qw, C)
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32")
import torch
import torch.backends
from bitblas import tvm as tvm
from tvm import DataType
from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from bitblas.base import simplify_prim_func
torch.manual_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
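# Only rows that are exactly 512 bits wide are swizzled; any other shape falls back to the identity layout below.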
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
print(src_code)
if in_dtype == "int8":
A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
print(f"Latency: {latency}")
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
if __name__ == "__main__":
# bitblas.testing.main()
# assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
# assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")
import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
import os
from transformers import GenerationConfig
import time
filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")
model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")
def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Build position ids matching the prompt length
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)
start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
end_time = time.time()
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time
print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")
return generated_text
def main():
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
# naive model inference
output = qmodel(input_ids)
print("original model output:", output)
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
if __name__ == "__main__":
main()
---
license: mit
---
This is a BitBLAS implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with the BitBLAS INT8xINT2 kernel. We also evaluated the model's correctness and performance with `eval_correctness.py` and `benchmark_inference_latency.py`.
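For orientation, the sketch below shows how an INT8xINT2 operator can be instantiated through the public BitBLAS `MatmulConfig`/`Matmul` API. It is a minimal, hedged example: the shapes are illustrative and the actual configuration used inside `utils_quant.py` may differ.

```python
# Minimal sketch (assumed API usage, not a copy of utils_quant.py): build a standalone
# BitBLAS INT8xINT2 matmul and run it on packed 2-bit (ternary) weights.
import torch
import bitblas

config = bitblas.MatmulConfig(
    M=1,                 # GEMV-like decoding shape
    N=3200,              # output features (illustrative)
    K=8640,              # input features (illustrative)
    A_dtype="int8",      # int8 activations
    W_dtype="int2",      # 2-bit weights (ternary values fit in int2)
    accum_dtype="int32",
    out_dtype="int32",
    layout="nt",         # A row-major, W transposed
    with_bias=False,
)
matmul = bitblas.Matmul(config=config)

# Pack a ternary weight matrix into the BitBLAS compressed layout, then execute.
w = torch.randint(-1, 2, (3200, 8640), dtype=torch.int8).cuda()
wq = matmul.transform_weight(w)
a = torch.randint(-8, 8, (1, 8640), dtype=torch.int8).cuda()
c = matmul(a, wq)        # int32 result of shape (1, 3200)
```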
## Latest News
- 08/09/2024 ✨: We provide a more efficient BitNet implementation with vLLM, which requires specially prepared model checkpoints. To create the checkpoints and learn how to deploy them, please check out [Make Checkpoints for vLLM](#make-checkpoints-for-vllm).
## Make Checkpoints for vLLM
We provide two scripts to make the checkpoints for vLLM. The first script, `generate_bitnet_model_native_format.sh`, creates a checkpoint with fp16 uncompressed metadata; the main difference from the original checkpoint is the `quant_config.json`, which allows vLLM to load the model and execute it with the quant extension.
```bash
# move to the integration directory
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```
The second script, `generate_bitnet_model_bitblas_format.sh`, creates a checkpoint with BitBLAS compressed metadata, which avoids the online dequantize stage during vLLM profiling and leads to more efficient memory utilization.
```bash
./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
```
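For reference, the BitBLAS-format conversion also rewrites `quantize_config.json`. The sketch below mirrors the fields the checkpoint conversion script in this repo adds on top of the original config; treat it as an illustration rather than the script itself.

```python
import json

# Fields the BitBLAS conversion adds on top of the original quantize_config.json,
# mirroring model.quantize(fuse_qkv=True, fuse_gateup=True) in the conversion script.
with open("quantize_config.json") as f:
    quant_config = json.load(f)
quant_config["checkpoint_format"] = "bitblas"
quant_config["fuse_qkv"] = True
quant_config["fuse_gateup"] = True
with open("quantize_config.json", "w") as f:
    json.dump(quant_config, f)
```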
Finally, you can use the checkpoint in vLLM with:
```bash
cd vllm_workspace
# inference with the ckpt with fp16 uncompressed metadata
python3 inference_with_native_format.py
# inference with the ckpt with BitBLAS compressed metadata
python3 inference_with_bitblas_format.py
```
## BitBLAS Results
### Performance
**Note:** To reproduce the BitBLAS results, please check out `benchmark_inference_latency.py`. To reproduce the results of the original model, please check out the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 |
| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 |
| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 |
### On-the-Fly GPU Memory Footprint
We measured the GPU memory footprint with the `nvidia-smi` command. Please check out `nvidia_measure_memory.sh` to record the real-time GPU memory usage, and then start the `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage.
| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** |
|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB |
| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB |
| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB |
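The numbers above were collected with `nvidia-smi`. If you prefer sampling from Python, a minimal sketch along these lines (a hypothetical helper, not part of the repo) reports the same per-device figure while the benchmark runs:

```python
# Hypothetical helper: poll nvidia-smi once per second and keep the peak device
# memory usage while benchmark_model_10k_loops.py runs in another terminal.
import subprocess
import time

def gpu_memory_used_mb(gpu_index: int = 0) -> int:
    out = subprocess.check_output([
        "nvidia-smi", f"--id={gpu_index}",
        "--query-gpu=memory.used", "--format=csv,noheader,nounits",
    ])
    return int(out.decode().strip().splitlines()[0])

peak = 0
try:
    while True:
        peak = max(peak, gpu_memory_used_mb())
        time.sleep(1.0)
except KeyboardInterrupt:
    print(f"peak GPU memory: {peak} MB")
```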
## PPL and Zero-shot Accuracy
The numbers are reported from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B); please check out `eval_ppl.py`.
PPL and zero-shot accuracy:
| Models | PPL | ARCe | ARCc | HS | BQ | OQ | PQ | WGe | Avg |
|--------|-----|------|------|----|----|----|----|-----|-----|
| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 |
| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 |
| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 |
| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 |
| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 |
| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 |
| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 |
| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 |
| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 |
The differences between the reported numbers and the reproduced results likely stem from variations in training data processing, seeds, or other random factors.
## Citations
```bibtex
@article{ma2024era,
title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits},
author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu},
journal={arXiv preprint arXiv:2402.17764},
year={2024}
}
```
import argparse
import torch
import bitblas
from transformers.utils.hub import cached_file
import os
from transformers import GenerationConfig
import time
import json
import sys
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../")
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B")
parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()
model_name_or_path = args.model_name_or_path
saved_model_path = os.path.join(
dirpath, "models",
f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Build position ids matching the prompt length
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)
start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
end_time = time.time()
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time
print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")
return generated_text
def main():
model = (
BitnetForCausalLM.from_pretrained(
model_name_or_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half())
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
# naive model inference
output = model(input_ids)
print("original model output:", output)
model.quantize(fuse_qkv=True, fuse_gateup=True)
print("original model generated text:")
print(generate_text(model, tokenizer, "Hi, ", max_length=100))
model.save_pretrained(saved_model_path)
# load quant config
quant_config_path = cached_file(model_name_or_path, "quantize_config.json")
with open(quant_config_path, "r") as f:
quant_config = json.load(f)
print("quant config:")
print(quant_config)
quant_config["checkpoint_format"] = "bitblas"
quant_config["fuse_qkv"] = True
quant_config["fuse_gateup"] = True
# save quant config
quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
with open(quant_config_path, "w") as f:
json.dump(quant_config, f)
print("quant config saved to:", quant_config_path)
# copy benchmark files into the saved model path
file_list = [
"configuration_bitnet.py",
"eval_utils.py",
"modeling_bitnet.py",
"tokenization_bitnet.py",
"utils_quant.py",
"README.md",
]
for file in file_list:
file_path = cached_file(model_name_or_path, file)
os.system(f"cp {file_path} {saved_model_path}")
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
if __name__ == '__main__':
main()