Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
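The hunks below are formatting-only churn from the Yapf-to-ruff migration: ruff format normalizes strings to double quotes, honors the magic trailing comma (one argument per line once a trailing comma is present), adds spaces around `**` and slice colons when the operands are compound expressions, and re-wraps lines that Yapf had split with hanging indents. The repository's lint configuration change itself is not shown in this excerpt. A minimal before/after sketch of the pattern, using a hypothetical `flash_block` helper that is not part of this diff:

```python
# Yapf-era style: single quotes, tight ** and slice operators, arguments
# packed onto one line.
def flash_block(q, scale=None, dtype='float16'):
    if scale is None:
        scale = (1.0 / q.shape[-1])**0.5
    return q[:, 0:q.shape[1] // 2] * scale


# ruff format style: double quotes, a trailing comma keeps one parameter per
# line, and spaces appear around ** and the slice colon because the operands
# are compound expressions. (Redefinition is intentional for the comparison.)
def flash_block(
    q,
    scale=None,
    dtype="float16",
):
    if scale is None:
        scale = (1.0 / q.shape[-1]) ** 0.5
    return q[:, 0 : q.shape[1] // 2] * scale
```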
......@@ -18,9 +18,11 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(
batch,
heads,
......@@ -33,12 +35,13 @@ def flashattn(
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16"):
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
......@@ -58,13 +61,12 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -79,7 +81,7 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......@@ -102,8 +104,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -147,53 +148,51 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
......@@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(batch: int = 1,
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -289,19 +282,17 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(
lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
......@@ -311,19 +302,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -19,9 +19,11 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(
batch,
heads,
......@@ -34,13 +36,13 @@ def flashattn(
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16"):
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
......@@ -61,13 +63,12 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -82,7 +83,7 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......@@ -105,8 +106,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -150,25 +150,25 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(
start,
......@@ -176,34 +176,33 @@ def flashattn(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
......@@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(batch: int = 1,
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -299,15 +292,14 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
......@@ -317,19 +309,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO")
def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch
input_ids = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
# Generate cos and sin values (commented out as not used in generation)
seq_length = input_ids.size(1)
......@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
end_time = time.time()
# Decode the output ids to text
generated_texts = [
tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids
]
generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]
generation_time = end_time - start_time
num_tokens = sum(len(output_id) for output_id in output_ids)
......@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
......@@ -74,25 +71,29 @@ def profile(model, input_data):
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
model_path = "1bitLLM/bitnet_b1_58-3B"
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=16, type=int)
parser.add_argument('--in_seq_len', default=32, type=int)
parser.add_argument('--out_seq_len', default=128, type=int)
parser.add_argument('--bitblas', action='store_true')
parser.add_argument("--bs", default=16, type=int)
parser.add_argument("--in_seq_len", default=32, type=int)
parser.add_argument("--out_seq_len", default=128, type=int)
parser.add_argument("--bitblas", action="store_true")
args = parser.parse_args()
bs = args.bs
in_seq_len = args.in_seq_len
out_seq_len = args.out_seq_len
is_bitblas = args.bitblas
model = BitnetForCausalLM.from_pretrained(
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
)
.cuda()
.half()
)
if is_bitblas:
with torch.no_grad():
model.quantize()
......@@ -109,5 +110,5 @@ def main():
print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
......@@ -35,8 +36,8 @@ def profile(model, input_data):
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
......@@ -52,5 +53,5 @@ def main():
print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
"""LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
......@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig):
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}")
raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
......@@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
......@@ -69,18 +69,22 @@ def profile(model, input_data):
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
model_path = "1bitLLM/bitnet_b1_58-3B"
def main():
model = BitnetForCausalLM.from_pretrained(
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half()
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
input_id = tokenizer("Hello")['input_ids']
input_id = tokenizer("Hello")["input_ids"]
input_id = torch.tensor(input_id).unsqueeze(0).cuda()
print("original model generated text:")
......@@ -91,5 +95,5 @@ def main():
print(generate_text(model, tokenizer, "Hello", max_length=100))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
......@@ -35,17 +36,17 @@ def profile(model, input_data):
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB")
with torch.no_grad():
model._post_process_weights()
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB")
print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -15,9 +15,9 @@ from tqdm import tqdm
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument('--seqlen', default=2048, type=int)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
parser.add_argument("--seqlen", default=2048, type=int)
def calulate_loss(model, input, loss_fct):
......@@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct):
def main(args):
datasets = ['c4', 'wikitext2']
model = BitnetForCausalLM.from_pretrained(
datasets = ["c4", "wikitext2"]
model = (
BitnetForCausalLM.from_pretrained(
args.hf_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
)
.cuda()
.half()
)
with torch.no_grad():
model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
......@@ -48,9 +52,9 @@ def main(args):
for ii in progress:
input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
loss = calulate_loss(model, input, loss_fct)
count += (input.size(-1) - 1)
count += input.size(-1) - 1
acc_loss += loss.item()
progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}")
progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")
avg_loss = acc_loss / count / math.log(2)
ppl.append(2**avg_loss)
......@@ -60,7 +64,7 @@ def main(args):
print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == '__main__':
if __name__ == "__main__":
torch.set_grad_enabled(False)
args = parser.parse_args()
random.seed(args.seed)
......
......@@ -15,21 +15,17 @@ def set_seed(seed):
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
if dataset_name == "wikitext2":
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
testdata = "".join(testdata['text']).split('\n')
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testdata = "".join(testdata["text"]).split("\n")
elif dataset_name == "c4":
testdata = load_dataset(
'allenai/c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')['text']
testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[
"text"
]
else:
raise NotImplementedError
testdata = [item for item in testdata if item != ""]
tokenized_text = [
tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id]
for item in testdata
]
tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata]
data, doc = [], [tokenizer.bos_token_id]
for sen in tokenized_text:
......@@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__()
......@@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM):
return out
def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
......@@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode(
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
]
T.call_extern(
......@@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......@@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_decode_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
......
......@@ -8,11 +8,13 @@ import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tilelang.intrinsics.mma_layout import (
make_mma_swizzle_layout as make_swizzle_layout,)
make_mma_swizzle_layout as make_swizzle_layout,
)
import numpy as np
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,)
INT4TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
torch.manual_seed(42)
......@@ -208,11 +210,9 @@ def bitnet_158_int8xint2_prefill(
threads=threads,
prelude=decode_i2s_to_i8s,
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(
B_dequantize_shared_shape, in_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
......@@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill(
T.clear(C_frag)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill(
for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = (
i * threads * local_size_compressed +
thread_bindings * local_size_compressed + v)
index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v
vi, vj = T.index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj]
......@@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill(
)
for v in T.vectorized(0, local_size):
index = (i * threads * local_size + thread_bindings * local_size + v)
index = i * threads * local_size + thread_bindings * local_size + v
vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_frag,
......@@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
return new_qweight.view(np.int8)
def assert_bitnet_158_int8xint2_prefill_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True):
def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(program)
kernel = tilelang.compile(program)
......
......@@ -6,7 +6,8 @@ from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from bitblas.base import simplify_prim_func
torch.manual_seed(0)
......@@ -106,7 +107,6 @@ def tl_matmul(
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -116,10 +116,12 @@ def tl_matmul(
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -127,7 +129,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -137,7 +138,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def main():
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
......
......@@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()
model_name_or_path = args.model_name_or_path
saved_model_path = os.path.join(
dirpath, "models",
f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
saved_model_path = (
os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
)
def generate_text(model, tokenizer, prompt, max_length=100):
......@@ -67,7 +67,10 @@ def main():
model_name_or_path,
use_flash_attention_2=False,
torch_dtype=torch.float16,
).cuda().half())
)
.cuda()
.half()
)
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
......@@ -112,10 +115,16 @@ def main():
file_path = cached_file(model_name_or_path, file)
os.system(f"cp {file_path} {saved_model_path}")
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
qmodel = (
BitnetForCausalLM.from_quantized(
saved_model_path,
)
.cuda()
.half()
)
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))
if __name__ == '__main__':
if __name__ == "__main__":
main()
This diff is collapsed.
......@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
......@@ -37,12 +38,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer":
"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer":
"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
......@@ -159,14 +158,10 @@ class BitnetTokenizer(PreTrainedTokenizer):
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(
bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(
eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(
unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(
pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
if legacy is None:
logger.warning_once(
......@@ -174,7 +169,8 @@ class BitnetTokenizer(PreTrainedTokenizer):
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565")
" https://github.com/huggingface/transformers/pull/24565"
)
legacy = True
self.legacy = legacy
......@@ -214,8 +210,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf(
f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
......@@ -261,8 +256,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
tokens = super().tokenize(text, **kwargs)
if len(tokens
) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]
return tokens
......@@ -284,7 +278,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
......@@ -332,12 +326,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(save_directory,
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"])
out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
self.vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
......@@ -357,10 +348,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
return output
def get_special_tokens_mask(self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
......@@ -377,20 +367,16 @@ class BitnetTokenizer(PreTrainedTokenizer):
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id +
([0] * len(token_ids_1)) + eos_token_id)
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id
def create_token_type_ids_from_sequences(self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
......@@ -473,9 +459,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
"{% elif message['role'] == 'assistant' %}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{% endif %}"
"{% endfor %}")
template = template.replace("USE_DEFAULT_PROMPT",
"true" if self.use_default_system_prompt else "false")
"{% endfor %}"
)
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
......
......@@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1):
def activation_quant(x, num_bits=8):
dtype = x.dtype
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
class BitLinearBitBLAS(nn.Module):
def __init__(
self,
in_features: int,
......@@ -68,7 +67,7 @@ class BitLinearBitBLAS(nn.Module):
self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)
self.format = "bitnet"
self.Qp = 2**(self.input_bits - 1) - 1
self.Qp = 2 ** (self.input_bits - 1) - 1
def _get_or_create_bitblas_operator(self, config, enable_tuning):
if global_operator_cache.size() == 0:
......@@ -99,8 +98,7 @@ class BitLinearBitBLAS(nn.Module):
@classmethod
def from_bit_linear(cls, bitlinear, weight_group=1):
bitblas_linear = cls(
bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
bitblas_linear.register_buffer("qweight", qweight)
bitblas_linear.register_buffer("sw", sw)
......@@ -158,8 +156,8 @@ class BitLinearBitBLAS(nn.Module):
@torch.compile
def activation_quant(self, x, num_bits=8):
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8), s
......@@ -173,9 +171,8 @@ class BitLinearBitBLAS(nn.Module):
# for the correctness evaluation.
def native_forward(self, input):
quant_input = (input + (activation_quant(input, self.input_bits) - input).detach())
quant_weight = (
self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach())
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
......@@ -214,7 +211,6 @@ class BitLinearBitBLAS(nn.Module):
# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):
def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
super(BitLinear, self).__init__(*kargs, **kwargs)
"""
......@@ -224,10 +220,8 @@ class BitLinear(nn.Linear):
self.input_bits = input_bits
def forward(self, input):
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) -
self.weight).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
......
......@@ -20,7 +20,7 @@ from transformers import (
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel)
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
......@@ -56,12 +56,13 @@ else:
class _ImageAssets(_ImageAssetsBase):
def __init__(self) -> None:
super().__init__([
super().__init__(
[
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
])
]
)
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
"""
......@@ -136,7 +137,6 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
......@@ -166,7 +166,8 @@ class HfRunner:
SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype))
).to(dtype=torch_dtype)
)
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
......@@ -184,7 +185,8 @@ class HfRunner:
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs,
))
)
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
......@@ -204,8 +206,7 @@ class HfRunner:
)
except Exception:
logger.warning(
"Unable to auto-load processor from HuggingFace for "
"model %s. Using tokenizer instead.",
"Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.",
model_name,
)
self.processor = self.tokenizer
......@@ -362,7 +363,7 @@ class HfRunner:
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if (getattr(self.model.get_output_embeddings(), "bias", None) is not None):
if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
......@@ -389,8 +390,7 @@ class HfRunner:
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts)
......@@ -409,7 +409,6 @@ def hf_runner():
class VllmRunner:
def __init__(
self,
model_name: str,
......@@ -514,12 +513,10 @@ class VllmRunner:
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]
def generate_beam_search(
self,
......
......@@ -39,8 +39,7 @@ with VllmRunner(
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"],
max_tokens=1024)
bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024)
print("bitnet inference:")
print(bitbnet_outputs[0][0])
print(bitbnet_outputs[0][1])