Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
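The hunks below only reformat existing code (quote style, argument wrapping, whitespace); the yapf-to-ruff switch itself is driven by lint configuration that is not visible in this view. A minimal sketch of the kind of pyproject.toml stanza such a migration typically relies on (the option names are real ruff settings; the concrete values are illustrative, not taken from this commit):

[tool.ruff]
line-length = 120              # illustrative; the project's actual limit is not shown in these hunks

[tool.ruff.format]
quote-style = "double"         # consistent with the single-to-double quote changes below

Running ruff format over the repository (and ruff check --fix for lint rules) then produces quoting, wrapping, and spacing changes of the kind shown in the hunks that follow.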
......@@ -18,27 +18,30 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16"):
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
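As background on the constant above: folding log2(e) ≈ 1.44269504 into scale lets the kernel evaluate the softmax with T.exp2 instead of exp, using the identity exp(x) = 2 ** (x * log2(e)); exp2 generally lowers to a cheaper hardware instruction. A standalone check of the identity (illustrative, not part of the kernel):

import math

# exp(x) == 2 ** (x * log2(e)); the kernel bakes log2(e) into `scale` so it can call T.exp2.
x = 0.73
assert abs(math.exp(x) - 2 ** (x * 1.44269504)) < 1e-7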
......@@ -58,13 +61,12 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
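The masking above combines causality with a sliding window: query position q_idx may attend to key position k_idx only when q_idx >= k_idx and q_idx < k_idx + window_size. A dense reference version of the same predicate, with illustrative sizes (not part of the kernel):

import torch

seq_q, seq_kv, past_len, window_size = 8, 12, 4, 4  # illustrative values
q_idx = torch.arange(seq_q)[:, None] + past_len     # absolute query positions
k_idx = torch.arange(seq_kv)[None, :]
mask = (q_idx >= k_idx) & (q_idx < k_idx + window_size)
# scores.masked_fill(~mask, float("-inf")) would reproduce the -inf initialization above.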
......@@ -79,18 +81,18 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -102,8 +104,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -118,19 +119,19 @@ def flashattn(
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
......@@ -147,53 +148,51 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
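The logsum update marked "The only change for attention sink" adds exp(sink - sm_scale * row_max) to each row's softmax denominator: the sink acts as one extra logit per head that absorbs probability mass but contributes no value vector. A dense, unoptimized sketch of the same math (names and shapes are illustrative, not the kernel's API):

import torch

def softmax_with_sink(scores: torch.Tensor, sink: float) -> torch.Tensor:
    # scores: [seq_q, seq_kv] already-scaled attention logits; sink: per-head scalar logit.
    m = scores.max(dim=-1, keepdim=True).values
    num = torch.exp(scores - m)
    denom = num.sum(dim=-1, keepdim=True) + torch.exp(sink - m)  # sink enlarges only the denominator
    return num / denom                                           # rows no longer sum exactly to 1

probs = softmax_with_sink(torch.randn(4, 6), sink=0.5)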
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
......@@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
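For reference, the einsum "bhmqk,bkhmd->bqhmd" above contracts over key positions k while keeping batch, query position, key/value heads, and the query groups that share each key/value head (grouped-query attention). A shape-only check with illustrative sizes (the reshapes that build these tensors sit in the truncated part of the file):

import torch

scores = torch.rand(2, 4, 2, 8, 16)   # [batch, kv_heads, groups, num_queries, num_keys]
value = torch.rand(2, 16, 4, 2, 32)   # [batch, num_keys, kv_heads, groups, head_dim]
out = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value)
assert out.shape == (2, 8, 4, 2, 32)  # [batch, num_queries, kv_heads, groups, head_dim]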
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
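The FLOP estimate above counts the two GEMMs of each attention block (Q·K^T and P·V) at 2·B·H·Sq·Skv·D multiply-accumulates each, halves that for the causal mask, and for sliding windows uses min(window_size, seq_kv // 2) as a rough effective key count. A worked example with the script's default sizes (illustrative, not part of the script):

B, H, Sq, Skv, D = 8, 32, 4096, 4096, 128             # CLI defaults further below
flops_per_matmul = 2.0 * B * H * Sq * Skv * D * 0.5   # causal mask keeps roughly half the pairs
total_flops = 2 * flops_per_matmul                    # Q·K^T plus P·V
print(f"{total_flops / 1e12:.2f} TFLOP")              # ~1.10 TFLOP per forward pass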
......@@ -289,19 +282,17 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(
lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
......@@ -311,19 +302,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -19,28 +19,30 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16"):
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
......@@ -61,13 +63,12 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -82,18 +83,18 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -105,8 +106,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -121,19 +121,19 @@ def flashattn(
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
......@@ -150,60 +150,59 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
......@@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -299,15 +292,14 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
......@@ -317,19 +309,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO")
def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch
input_ids = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
# Generate cos and sin values (commented out as not used in generation)
seq_length = input_ids.size(1)
......@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
end_time = time.time()
# Decode the output ids to text
generated_texts = [
tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids
]
generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]
generation_time = end_time - start_time
num_tokens = sum(len(output_id) for output_id in output_ids)
......@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
......@@ -74,25 +71,29 @@ def profile(model, input_data):
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
model_path = "1bitLLM/bitnet_b1_58-3B"
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=16, type=int)
parser.add_argument('--in_seq_len', default=32, type=int)
parser.add_argument('--out_seq_len', default=128, type=int)
parser.add_argument('--bitblas', action='store_true')
parser.add_argument("--bs", default=16, type=int)
parser.add_argument("--in_seq_len", default=32, type=int)
parser.add_argument("--out_seq_len", default=128, type=int)
parser.add_argument("--bitblas", action="store_true")
args = parser.parse_args()
bs = args.bs
in_seq_len = args.in_seq_len
out_seq_len = args.out_seq_len
is_bitblas = args.bitblas
model = BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
if is_bitblas:
with torch.no_grad():
model.quantize()
......@@ -109,5 +110,5 @@ def main():
print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
......@@ -35,8 +36,8 @@ def profile(model, input_data):
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
......@@ -52,5 +53,5 @@ def main():
print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
"""LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
......@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig):
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}")
raise ValueError(f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
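Per the validation above, a well-formed rope_scaling value is a two-field dict whose type is "linear" or "dynamic" and whose factor is a float strictly greater than 1. An example that passes these checks (illustrative):

rope_scaling = {"type": "linear", "factor": 2.0}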
......@@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
......@@ -69,18 +69,22 @@ def profile(model, input_data):
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
model_path = "1bitLLM/bitnet_b1_58-3B"
def main():
model = BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half()
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
input_id = tokenizer("Hello")['input_ids']
input_id = tokenizer("Hello")["input_ids"]
input_id = torch.tensor(input_id).unsqueeze(0).cuda()
print("original model generated text:")
......@@ -91,5 +95,5 @@ def main():
print(generate_text(model, tokenizer, "Hello", max_length=100))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
......@@ -35,17 +36,17 @@ def profile(model, input_data):
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB")
with torch.no_grad():
model._post_process_weights()
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -15,9 +15,9 @@ from tqdm import tqdm
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument('--seqlen', default=2048, type=int)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
parser.add_argument("--seqlen", default=2048, type=int)
def calulate_loss(model, input, loss_fct):
......@@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct):
def main(args):
datasets = ['c4', 'wikitext2']
model = BitnetForCausalLM.from_pretrained(
args.hf_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
datasets = ["c4", "wikitext2"]
model = (
BitnetForCausalLM.from_pretrained(
args.hf_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
with torch.no_grad():
model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
......@@ -48,9 +52,9 @@ def main(args):
for ii in progress:
input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
loss = calulate_loss(model, input, loss_fct)
count += (input.size(-1) - 1)
count += input.size(-1) - 1
acc_loss += loss.item()
progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}")
progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")
avg_loss = acc_loss / count / math.log(2)
ppl.append(2**avg_loss)
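The division by math.log(2) converts the accumulated cross-entropy from nats to bits, so 2 ** avg_loss equals exp of the nat-valued mean loss, i.e. the usual perplexity. A standalone check of that identity (illustrative):

import math

nll_nats = 2.1  # illustrative mean token cross-entropy in nats
assert abs(2 ** (nll_nats / math.log(2)) - math.exp(nll_nats)) < 1e-9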
......@@ -60,7 +64,7 @@ def main(args):
print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == '__main__':
if __name__ == "__main__":
torch.set_grad_enabled(False)
args = parser.parse_args()
random.seed(args.seed)
......
......@@ -15,21 +15,17 @@ def set_seed(seed):
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
if dataset_name == "wikitext2":
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
testdata = "".join(testdata['text']).split('\n')
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testdata = "".join(testdata["text"]).split("\n")
elif dataset_name == "c4":
testdata = load_dataset(
'allenai/c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')['text']
testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[
"text"
]
else:
raise NotImplementedError
testdata = [item for item in testdata if item != ""]
tokenized_text = [
tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id]
for item in testdata
]
tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata]
data, doc = [], [tokenizer.bos_token_id]
for sen in tokenized_text:
......@@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__()
......@@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM):
return out
def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
......@@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode(
@T.prim_func
def kernel(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer(C_shape, out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
......@@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode(
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
]
T.call_extern(
......@@ -156,9 +155,9 @@ def bitnet_158_int8xint2_decode(
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
......@@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......@@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8)
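The compressed B layout stores four 2-bit weights per int8 byte (num_elems_per_byte = 4), which is why the K extent of B is divided by 4 before the on-chip decode. A minimal pack/unpack sketch of that idea, using plain little-endian packing (the real kernel additionally interleaves bytes, as in the interleave_weight helper above):

vals = [0, 1, 2, 3]                      # four 2-bit values in [0, 3]
packed = 0
for i, v in enumerate(vals):
    packed |= v << (2 * i)               # little-endian within the byte
unpacked = [(packed >> (2 * i)) & 0b11 for i in range(4)]
assert unpacked == vals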
def assert_bitnet_158_int8xint2_decode_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
......
......@@ -8,11 +8,13 @@ import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tilelang.intrinsics.mma_layout import (
make_mma_swizzle_layout as make_swizzle_layout,)
make_mma_swizzle_layout as make_swizzle_layout,
)
import numpy as np
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,)
INT4TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
torch.manual_seed(42)
......@@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
"""
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
"""
with T.Kernel(
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
threads=threads,
prelude=decode_i2s_to_i8s,
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
threads=threads,
prelude=decode_i2s_to_i8s,
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(
B_dequantize_shared_shape, in_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
......@@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill(
T.clear(C_frag)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill(
for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = (
i * threads * local_size_compressed +
thread_bindings * local_size_compressed + v)
index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v
vi, vj = T.index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj]
......@@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill(
)
for v in T.vectorized(0, local_size):
index = (i * threads * local_size + thread_bindings * local_size + v)
index = i * threads * local_size + thread_bindings * local_size + v
vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_frag,
......@@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_prefill_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
......
......@@ -6,7 +6,8 @@ from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from bitblas.base import simplify_prim_func
torch.manual_seed(0)
......@@ -101,12 +102,11 @@ def tl_matmul(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -116,10 +116,12 @@ def tl_matmul(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -127,7 +129,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -137,7 +138,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def main():
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
......
......@@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()
model_name_or_path = args.model_name_or_path
saved_model_path = os.path.join(
dirpath, "models",
f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
saved_model_path = (
os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
)
def generate_text(model, tokenizer, prompt, max_length=100):
......@@ -67,7 +67,10 @@ def main():
model_name_or_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half())
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
......@@ -112,10 +115,16 @@ def main():
file_path = cached_file(model_name_or_path, file)
os.system(f"cp {file_path} {saved_model_path}")
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
if __name__ == '__main__':
if __name__ == "__main__":
main()