Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
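This change is formatting-only: the same Python source is re-wrapped by ruff format (Black-style wrapping, double quotes, trailing commas) instead of Yapf. For example, an argparse call from this diff that Yapf wrapped as

    parser.add_argument(
        '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")

is emitted by ruff format as a single line with double-quoted strings:

    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")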
...@@ -18,27 +18,30 @@ def get_configs(): ...@@ -18,27 +18,30 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100) @autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit( @tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn( def flashattn(
batch, batch,
heads, heads,
seq_q, seq_q,
seq_kv, seq_kv,
dim, dim,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None, sm_scale=None,
block_M=64, block_M=64,
block_N=64, block_N=64,
num_stages=1, num_stages=1,
threads=128, threads=128,
dtype: str = "float16",
):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
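The constant 1.44269504 is log2(e): folding it into `scale` lets the softmax below use the cheaper `T.exp2` instead of `exp`, relying on the identity exp(x) = 2^(x * log2 e). A tiny stand-alone check of that identity (plain Python, not part of the kernel):

    import math

    x = 0.37
    assert math.isclose(math.exp(x), 2.0 ** (x * 1.4426950408889634))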
...@@ -58,13 +61,12 @@ def flashattn( ...@@ -58,13 +61,12 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j k_idx = k * block_N + j
if window_size is not None: if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else: else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -79,18 +81,18 @@ def flashattn( ...@@ -79,18 +81,18 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -102,8 +104,7 @@ def flashattn( ...@@ -102,8 +104,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention. # NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if window_size is not None: if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -118,19 +119,19 @@ def flashattn( ...@@ -118,19 +119,19 @@ def flashattn(
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype), Sinks: T.Tensor([heads], dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -147,53 +148,51 @@ def flashattn( ...@@ -147,53 +148,51 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype) sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)  # The only change for attention sink
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
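The sink contribution above enters only `logsum`, i.e. the per-head sink behaves as one extra logit that enlarges the softmax denominator without adding anything to the numerator. A rough NumPy sketch of that normalization for a single query row (illustrative only; it ignores the kernel's blockwise online rescaling and score scaling):

    import numpy as np

    def row_softmax_with_sink(scores, sink):
        # scores: attention logits for one query row; sink: that head's sink logit
        m = max(scores.max(), sink)         # shared max for numerical stability
        num = np.exp(scores - m)            # one numerator term per key
        den = num.sum() + np.exp(sink - m)  # the sink only enlarges the denominator
        return num / den                    # weights now sum to slightly less than 1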
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
...@@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor, ...@@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= seq_q assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
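The estimate counts two GEMMs per attention block (Q·K^T and P·V), each at 2·batch·heads·seq_q·seq_kv·dim multiply-adds, halved because the causal mask discards roughly half of the score matrix. A stand-alone version of the same arithmetic for the command-line default shapes:

    batch, heads, seq_q, seq_kv, dim = 8, 32, 4096, 4096, 128
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5  # causal mask keeps ~half the scores
    total_flops = 2 * flops_per_matmul                                   # one QK^T and one PV GEMM
    print(f"{total_flops / 1e12:.2f} TFLOPs")                            # ~1.10 TFLOPs for these shapes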
...@@ -289,19 +282,17 @@ def main(batch: int = 1, ...@@ -289,19 +282,17 @@ def main(batch: int = 1,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads, threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
...@@ -311,19 +302,13 @@ def main(batch: int = 1, ...@@ -311,19 +302,13 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
...@@ -19,28 +19,30 @@ def get_configs(): ...@@ -19,28 +19,30 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100) @autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit( @tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn( def flashattn(
batch, batch,
heads, heads,
seq_q, seq_q,
seq_kv, seq_kv,
dim, dim,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None, sm_scale=None,
block_M=128, block_M=128,
block_N=128, block_N=128,
num_stages=2, num_stages=2,
threads=256, threads=256,
dtype: str = "float16",
):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
...@@ -61,13 +63,12 @@ def flashattn( ...@@ -61,13 +63,12 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j k_idx = k * block_N + j
if window_size is not None: if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else: else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -82,18 +83,18 @@ def flashattn( ...@@ -82,18 +83,18 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -105,8 +106,7 @@ def flashattn( ...@@ -105,8 +106,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention. # NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if window_size is not None: if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -121,19 +121,19 @@ def flashattn( ...@@ -121,19 +121,19 @@ def flashattn(
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype), Sinks: T.Tensor([heads], dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -150,60 +150,59 @@ def flashattn( ...@@ -150,60 +150,59 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype) sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined( for k in T.Pipelined(
start, start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale)  # The only change for attention sink
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
# Following functions are adapted and optimized from # Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
...@@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor, ...@@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= seq_q assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
...@@ -299,15 +292,14 @@ def main(batch: int = 1, ...@@ -299,15 +292,14 @@ def main(batch: int = 1,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads, threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
...@@ -317,19 +309,13 @@ def main(batch: int = 1, ...@@ -317,19 +309,13 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
...@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO") ...@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO")
def generate_text_batch(model, tokenizer, prompts, max_length=100): def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch # Encode the input prompts as a batch
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
# Generate cos and sin values (commented out as not used in generation) # Generate cos and sin values (commented out as not used in generation)
seq_length = input_ids.size(1) seq_length = input_ids.size(1)
...@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): ...@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
end_time = time.time() end_time = time.time()
# Decode the output ids to text # Decode the output ids to text
generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]
generation_time = end_time - start_time generation_time = end_time - start_time
num_tokens = sum(len(output_id) for output_id in output_ids) num_tokens = sum(len(output_id) for output_id in output_ids)
...@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): ...@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
def profile(model, input_data): def profile(model, input_data):
import numpy as np import numpy as np
model = model.cuda() model = model.cuda()
model.eval() model.eval()
...@@ -74,25 +71,29 @@ def profile(model, input_data): ...@@ -74,25 +71,29 @@ def profile(model, input_data):
return np.mean(times) return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B' model_path = "1bitLLM/bitnet_b1_58-3B"
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=16, type=int) parser.add_argument("--bs", default=16, type=int)
parser.add_argument('--in_seq_len', default=32, type=int) parser.add_argument("--in_seq_len", default=32, type=int)
parser.add_argument('--out_seq_len', default=128, type=int) parser.add_argument("--out_seq_len", default=128, type=int)
parser.add_argument('--bitblas', action='store_true') parser.add_argument("--bitblas", action="store_true")
args = parser.parse_args() args = parser.parse_args()
bs = args.bs bs = args.bs
in_seq_len = args.in_seq_len in_seq_len = args.in_seq_len
out_seq_len = args.out_seq_len out_seq_len = args.out_seq_len
is_bitblas = args.bitblas is_bitblas = args.bitblas
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
if is_bitblas: if is_bitblas:
with torch.no_grad(): with torch.no_grad():
model.quantize() model.quantize()
...@@ -109,5 +110,5 @@ def main(): ...@@ -109,5 +110,5 @@ def main():
print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM ...@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data): def profile(model, input_data):
import time import time
import numpy as np import numpy as np
model = model.cuda() model = model.cuda()
model.eval() model.eval()
...@@ -35,8 +36,8 @@ def profile(model, input_data): ...@@ -35,8 +36,8 @@ def profile(model, input_data):
def main(): def main():
model = BitnetForCausalLM.from_pretrained( model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B', "1bitLLM/bitnet_b1_58-3B",
device_map='auto', device_map="auto",
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
use_flash_attention_2=True, use_flash_attention_2=True,
torch_dtype=torch.float16, torch_dtype=torch.float16,
...@@ -52,5 +53,5 @@ def main(): ...@@ -52,5 +53,5 @@ def main():
print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" LLaMA model configuration""" """LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging from transformers.utils import logging
...@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig): ...@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig):
return return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None) rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
...@@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): ...@@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def profile(model, input_data): def profile(model, input_data):
import numpy as np import numpy as np
model = model.cuda() model = model.cuda()
model.eval() model.eval()
...@@ -69,18 +69,22 @@ def profile(model, input_data): ...@@ -69,18 +69,22 @@ def profile(model, input_data):
return np.mean(times) return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B' model_path = "1bitLLM/bitnet_b1_58-3B"
def main(): def main():
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
input_id = tokenizer("Hello")['input_ids'] input_id = tokenizer("Hello")["input_ids"]
input_id = torch.tensor(input_id).unsqueeze(0).cuda() input_id = torch.tensor(input_id).unsqueeze(0).cuda()
print("original model generated text:") print("original model generated text:")
...@@ -91,5 +95,5 @@ def main(): ...@@ -91,5 +95,5 @@ def main():
print(generate_text(model, tokenizer, "Hello", max_length=100)) print(generate_text(model, tokenizer, "Hello", max_length=100))
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM ...@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data): def profile(model, input_data):
import time import time
import numpy as np import numpy as np
model = model.cuda() model = model.cuda()
model.eval() model.eval()
...@@ -35,17 +36,17 @@ def profile(model, input_data): ...@@ -35,17 +36,17 @@ def profile(model, input_data):
def main(): def main():
model = BitnetForCausalLM.from_pretrained( model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B', "1bitLLM/bitnet_b1_58-3B",
device_map='auto', device_map="auto",
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
use_flash_attention_2=True, use_flash_attention_2=True,
torch_dtype=torch.float16, torch_dtype=torch.float16,
).half() ).half()
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB")
with torch.no_grad(): with torch.no_grad():
model._post_process_weights() model._post_process_weights()
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB")
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -15,9 +15,9 @@ from tqdm import tqdm ...@@ -15,9 +15,9 @@ from tqdm import tqdm
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int) parser.add_argument("--seed", default=0, type=int)
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
parser.add_argument('--seqlen', default=2048, type=int) parser.add_argument("--seqlen", default=2048, type=int)
def calulate_loss(model, input, loss_fct): def calulate_loss(model, input, loss_fct):
...@@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct): ...@@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct):
def main(args): def main(args):
datasets = ['c4', 'wikitext2'] datasets = ["c4", "wikitext2"]
model = (
BitnetForCausalLM.from_pretrained(
args.hf_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
with torch.no_grad(): with torch.no_grad():
model._post_process_weights() model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
...@@ -48,9 +52,9 @@ def main(args): ...@@ -48,9 +52,9 @@ def main(args):
for ii in progress: for ii in progress:
input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
loss = calulate_loss(model, input, loss_fct) loss = calulate_loss(model, input, loss_fct)
count += (input.size(-1) - 1) count += input.size(-1) - 1
acc_loss += loss.item() acc_loss += loss.item()
progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}") progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")
avg_loss = acc_loss / count / math.log(2) avg_loss = acc_loss / count / math.log(2)
ppl.append(2**avg_loss) ppl.append(2**avg_loss)
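As a reminder of the identity used here: dividing the accumulated cross-entropy (in nats) by log(2) gives bits per token, and perplexity is 2 raised to the bits per token, which equals e raised to the nats per token. A self-contained restatement with a made-up loss value:

    import math

    nats_per_token = 2.079                        # hypothetical average cross-entropy
    bits_per_token = nats_per_token / math.log(2)
    assert math.isclose(2 ** bits_per_token, math.exp(nats_per_token))  # both are the perplexity (~8.0)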
...@@ -60,7 +64,7 @@ def main(args): ...@@ -60,7 +64,7 @@ def main(args):
print("Avg PPL:", sum(ppl) / len(ppl)) print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == '__main__': if __name__ == "__main__":
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
args = parser.parse_args() args = parser.parse_args()
random.seed(args.seed) random.seed(args.seed)
......
...@@ -15,21 +15,17 @@ def set_seed(seed): ...@@ -15,21 +15,17 @@ def set_seed(seed):
def get_test_dataset(dataset_name, tokenizer, seqlen=2048): def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
if dataset_name == "wikitext2": if dataset_name == "wikitext2":
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testdata = "".join(testdata['text']).split('\n') testdata = "".join(testdata["text"]).split("\n")
elif dataset_name == "c4": elif dataset_name == "c4":
testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[
"text"
]
else: else:
raise NotImplementedError raise NotImplementedError
testdata = [item for item in testdata if item != ""] testdata = [item for item in testdata if item != ""]
tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata]
data, doc = [], [tokenizer.bos_token_id] data, doc = [], [tokenizer.bos_token_id]
for sen in tokenized_text: for sen in tokenized_text:
...@@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): ...@@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
class LMEvalAdaptor(BaseLM): class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__() super().__init__()
...@@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM): ...@@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM):
return out return out
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
...@@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode( ...@@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode(
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype), B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer(C_shape, out_dtype), C: T.Buffer(C_shape, out_dtype),
): ):
with T.Kernel( with T.Kernel(
T.ceildiv(N, n_partition), T.ceildiv(N, n_partition),
M, M,
threads=(reduce_thread, n_partition), threads=(reduce_thread, n_partition),
) as ( ) as (
bx, bx,
by, by,
): ):
A_local = T.alloc_local((micro_size_k,), in_dtype) A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
...@@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode( ...@@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode(
for v in T.vectorized(micro_size_k_compressed): for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[ B_quant_local[v] = B[
bx * n_partition + ni, bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
] ]
T.call_extern( T.call_extern(
...@@ -156,9 +155,9 @@ def bitnet_158_int8xint2_decode( ...@@ -156,9 +155,9 @@ def bitnet_158_int8xint2_decode(
accum_res[0] += A_local[ki] * B_dequantize_local[ki] accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode( ...@@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode(
reduced_accum_res[0], reduced_accum_res[0],
kr, kr,
dtype="handle", dtype="handle",
)
)
if kr == 0: if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0] C[by, bx * n_partition + ni] = reduced_accum_res[0]
...@@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): ...@@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8) return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program) print(program)
kernel = tilelang.compile(program) kernel = tilelang.compile(program)
......
...@@ -8,11 +8,13 @@ import tilelang.language as T ...@@ -8,11 +8,13 @@ import tilelang.language as T
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import DataType from tvm import DataType
from tilelang.intrinsics.mma_layout import ( from tilelang.intrinsics.mma_layout import (
make_mma_swizzle_layout as make_swizzle_layout,
)
import numpy as np import numpy as np
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
torch.manual_seed(42) torch.manual_seed(42)
...@@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill( ...@@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype), B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
""" """
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel: This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters: Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects: Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
""" """
with T.Kernel( with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(N, block_N),
T.ceildiv(M, block_M), T.ceildiv(M, block_M),
threads=threads, threads=threads,
prelude=decode_i2s_to_i8s, prelude=decode_i2s_to_i8s,
) as (bx, by): ) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
...@@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill( ...@@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x") thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill( ...@@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill(
T.clear(C_frag) T.clear(C_frag)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill( ...@@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill(
for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v
vi, vj = T.index_to_coordinates(index, B_shared_shape) vi, vj = T.index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj] B_local[v] = B_shared[vi, vj]
...@@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill( ...@@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill(
) )
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = (i * threads * local_size + thread_bindings * local_size + v) index = i * threads * local_size + thread_bindings * local_size + v
vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v] B_dequantize_shared[vi, vj] = B_dequantize_local[v]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_frag, A_frag,
...@@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): ...@@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8) return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program) print(program)
kernel = tilelang.compile(program) kernel = tilelang.compile(program)
......
...@@ -6,7 +6,8 @@ from tvm import tl as TL ...@@ -6,7 +6,8 @@ from tvm import tl as TL
import tvm.tl.language as T import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.mma_macro_generator import ( from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from bitblas.base import simplify_prim_func from bitblas.base import simplify_prim_func
torch.manual_seed(0) torch.manual_seed(0)
...@@ -101,12 +102,11 @@ def tl_matmul( ...@@ -101,12 +102,11 @@ def tl_matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype), B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -116,10 +116,12 @@ def tl_matmul( ...@@ -116,10 +116,12 @@ def tl_matmul(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x") thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -127,7 +129,6 @@ def tl_matmul( ...@@ -127,7 +129,6 @@ def tl_matmul(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -137,7 +138,6 @@ def tl_matmul( ...@@ -137,7 +138,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
......
...@@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100): ...@@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def main(): def main():
# load quantized model # load quantized model
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:") # print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100)) # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
......
...@@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None) ...@@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
model_name_or_path = args.model_name_or_path model_name_or_path = args.model_name_or_path
saved_model_path = (
os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
)
def generate_text(model, tokenizer, prompt, max_length=100): def generate_text(model, tokenizer, prompt, max_length=100):
...@@ -67,7 +67,10 @@ def main(): ...@@ -67,7 +67,10 @@ def main():
model_name_or_path, model_name_or_path,
use_flash_attention_2=False, use_flash_attention_2=False,
torch_dtype=torch.float16, torch_dtype=torch.float16,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:") # print("original model generated text:")
...@@ -112,10 +115,16 @@ def main(): ...@@ -112,10 +115,16 @@ def main():
file_path = cached_file(model_name_or_path, file) file_path = cached_file(model_name_or_path, file)
os.system(f"cp {file_path} {saved_model_path}") os.system(f"cp {file_path} {saved_model_path}")
# load quantized model # load quantized model
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
print("quantized model generated text:") print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
if __name__ == '__main__': if __name__ == "__main__":
main() main()