OpenDAS / tilelang · Commits · Commit 29051439 (unverified)

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025.
Parent: e84b24bc

Changes: 467 · Showing 20 changed files with 514 additions and 658 deletions (+514, −658).
examples/attention_sink/example_mha_sink_fwd_bhsd.py  +92 −107
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py  +98 −112
examples/bitnet-1.58b/benchmark_generate.py  +18 −17
examples/bitnet-1.58b/benchmark_inference_latency.py  +5 −4
examples/bitnet-1.58b/configuration_bitnet.py  +5 −11
examples/bitnet-1.58b/eval_correctness.py  +13 −9
examples/bitnet-1.58b/eval_gpu_memory.py  +7 −6
examples/bitnet-1.58b/eval_ppl.py  +16 −12
examples/bitnet-1.58b/eval_utils.py  +7 −13
examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py  +15 −21
examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py  +34 −43
examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py  +11 −11
examples/bitnet-1.58b/load_from_quantized.py  +7 −1
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py  +15 −6
examples/bitnet-1.58b/modeling_bitnet.py  +109 −199
examples/bitnet-1.58b/tokenization_bitnet.py  +23 −37
examples/bitnet-1.58b/utils_quant.py  +9 −15
examples/bitnet-1.58b/vllm_workspace/conftest.py  +16 −19
examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py  +7 −8
examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py  +7 −7
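Every hunk in this commit is mechanical: the same tokens are re-wrapped from Yapf's hanging-indent conventions into ruff's Black-compatible style — double quotes, magic trailing commas, one argument per line in exploded calls, closing brackets dedented onto their own lines, and chained calls broken vertically. A minimal runnable sketch of the two spellings, using a hypothetical configure helper rather than code from the diff:

import_free_example = True  # plain Python, no tilelang needed for the style contrast

# Hypothetical helper, used only to contrast the two formatter styles.
def configure(out_idx, pass_configs):
    return (out_idx, pass_configs)

# Yapf (old style): hanging indent; the closing bracket hugs the last line.
cfg_yapf = configure(
    out_idx=[3], pass_configs={
        'TL_ENABLE_FAST_MATH': True,
    })

# ruff format (new style): double quotes; the trailing comma forces one
# argument per line and pushes the closing brackets onto their own lines.
cfg_ruff = configure(
    out_idx=[3],
    pass_configs={
        "TL_ENABLE_FAST_MATH": True,
    },
)

assert cfg_yapf == cfg_ruff  # formatting-only: both spellings are the same code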
examples/attention_sink/example_mha_sink_fwd_bhsd.py

@@ -18,9 +18,11 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3], pass_configs={
+    out_idx=[3],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def flashattn(
     batch,
     heads,
@@ -33,12 +35,13 @@ def flashattn(
     block_N=64,
     num_stages=1,
     threads=128,
-    dtype: str = "float16"):
+    dtype: str = "float16",
+):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
     if sm_scale is None:
         sm_scale = (1.0 / dim) ** 0.5
     scale = sm_scale * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
@@ -58,13 +61,12 @@ def flashattn(
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
         for i, j in T.Parallel(block_M, block_N):
             q_idx = bx * block_M + i + past_len
             k_idx = k * block_N + j
             if window_size is not None:
-                acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-                                             -T.infinity(acc_s.dtype))
+                acc_s[i, j] = T.if_then_else(
+                    q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)
+                )
             else:
                 acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -79,7 +81,7 @@ def flashattn(
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
         T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

     @T.macro
@@ -102,8 +104,7 @@ def flashattn(
         # NOTE(wt): check_inf is necessary for sliding window attention.
         for i in T.Parallel(block_M):
             if window_size is not None:
-                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
-                                               scores_max[i])
+                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
             scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
         for i, j in T.Parallel(block_M, block_N):
@@ -147,53 +148,51 @@ def flashattn(
             logsum = T.alloc_fragment([block_M], accum_dtype)
             sinks = T.alloc_fragment([block_M], dtype)

-            T.annotate_layout({
-                Q_shared: make_swizzled_layout(Q_shared),
-                K_shared: make_swizzled_layout(K_shared),
-                V_shared: make_swizzled_layout(V_shared),
-                O_shared: make_swizzled_layout(O_shared),
-            })
+            T.annotate_layout(
+                {
+                    Q_shared: make_swizzled_layout(Q_shared),
+                    K_shared: make_swizzled_layout(K_shared),
+                    V_shared: make_swizzled_layout(V_shared),
+                    O_shared: make_swizzled_layout(O_shared),
+                }
+            )
             T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             for i in T.Parallel(block_M):
                 sinks[i] = Sinks[by]

             end = T.min(
-                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.max(0, (bx * block_M + past_len - window_size) //
-                          block_N) if window_size is not None else 0
+                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)
+            )
+            start = (
+                T.max(0, (bx * block_M + past_len - window_size) // block_N)
+                if window_size is not None
+                else 0
+            )

             for k in T.Pipelined(start, end, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
-                        logsum)
+                Softmax(
+                    acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum
+                )
                 Rescale(acc_o, scores_scale)
                 MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
             for i in T.Parallel(block_M):
-                logsum[i] += T.exp2(sinks[i] * 1.44269504 -
-                                    scores_max[i] * scale)  # The only change for attention sink
+                logsum[i] += T.exp2(
+                    sinks[i] * 1.44269504 - scores_max[i] * scale
+                )  # The only change for attention sink
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] /= logsum[i]
             T.copy(acc_o, O_shared)
             T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main


 # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
-def ref_program(query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
-                dtype: torch.dtype = torch.float16) -> torch.Tensor:
+def ref_program(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    sinks: torch.Tensor,
+    sliding_window: Optional[int] = None,
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
     query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
     key = key.transpose(1, 2).contiguous()
     value = value.transpose(1, 2).contiguous()
@@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor,
     output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
-    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
-                            head_dim).to(dtype)
+    output = output.reshape(
+        batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim
+    ).to(dtype)
     return output.transpose(1, 2).contiguous()


-def gen_inputs(B, H, Sq, Skv, D,
-               dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
-    key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
-    value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
-    sinks = torch.randn([H], dtype=dtype, device='cuda')
+def gen_inputs(
+    B,
+    H,
+    Sq,
+    Skv,
+    D,
+    dtype=torch.float16,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
+    key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
+    value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
+    sinks = torch.randn([H], dtype=dtype, device="cuda")
     return query, key, value, sinks


-def main(batch: int = 1,
-         heads: int = 1,
-         seq_q: int = 256,
-         seq_kv: int = 256,
-         dim: int = 128,
-         window_size: Optional[int] = None,
-         dtype: str = "float16",
-         tune: bool = False):
+def main(
+    batch: int = 1,
+    heads: int = 1,
+    seq_q: int = 256,
+    seq_kv: int = 256,
+    dim: int = 128,
+    window_size: Optional[int] = None,
+    dtype: str = "float16",
+    tune: bool = False,
+):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
         assert window_size <= seq_q
         flops_per_matmul = 2.0 * batch * heads * min(
             window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
     else:
-        print('Using full attention.')
+        print("Using full attention.")
         flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
     total_flops = 2 * flops_per_matmul
@@ -289,19 +282,17 @@ def main(batch: int = 1,
         block_N=block_N,
         num_stages=num_stages,
         threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )
     Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
     torch.testing.assert_close(
         kernel(Q, K, V, sinks),
         ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
         rtol=1e-2,
-        atol=1e-2)
+        atol=1e-2,
+    )
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
+        lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500
+    )
     print("Ref: {:.2f} ms".format(latency))
     print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
@@ -311,19 +302,13 @@ def main(batch: int = 1,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument(
-        '--window_size',
-        type=int,
-        default=None,
-        help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument(
+        "--window_size", type=int, default=None, help="window size (default: None, which means full attention)"
+    )
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune")
     args = parser.parse_args()
-    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
-         args.tune)
+    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
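Both attention-sink examples hinge on the line tagged "# The only change for attention sink": after the streamed softmax loop, each head's learned sink logit is added to the softmax denominator (logsum) without contributing a value vector, so some probability mass leaks away from the real keys. A small self-contained PyTorch sketch of that reference semantics for a single head — names and shapes here are illustrative, not the kernel's API:

import torch

def sink_softmax(scores: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    # Row-wise softmax over [seq_q, seq_kv] scores with a per-head sink logit
    # that joins the denominator only (illustrative, not the kernel interface).
    m = torch.maximum(scores.max(dim=-1, keepdim=True).values, sink)  # stable max incl. sink
    num = torch.exp(scores - m)                                       # numerators for real keys
    denom = num.sum(dim=-1, keepdim=True) + torch.exp(sink - m)       # sink enlarges the denominator
    return num / denom                                                # sink emits no value vector

scores = torch.randn(4, 6)
sink = torch.tensor(0.5)
probs = sink_softmax(scores, sink)
assert torch.all(probs.sum(dim=-1) < 1.0)  # each row sums to < 1: mass leaked to the sink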
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

@@ -19,9 +19,11 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3], pass_configs={
+    out_idx=[3],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def flashattn(
     batch,
     heads,
@@ -34,13 +36,13 @@ def flashattn(
     block_N=128,
     num_stages=2,
     threads=256,
-    dtype: str = "float16"):
+    dtype: str = "float16",
+):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
     if sm_scale is None:
         sm_scale = (1.0 / dim) ** 0.5
     scale = sm_scale * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
@@ -61,13 +63,12 @@ def flashattn(
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
         for i, j in T.Parallel(block_M, block_N):
             q_idx = bx * block_M + i + past_len
             k_idx = k * block_N + j
             if window_size is not None:
-                acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-                                             -T.infinity(acc_s.dtype))
+                acc_s[i, j] = T.if_then_else(
+                    q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)
+                )
             else:
                 acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -82,7 +83,7 @@ def flashattn(
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
         T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

     @T.macro
@@ -105,8 +106,7 @@ def flashattn(
         # NOTE(wt): check_inf is necessary for sliding window attention.
         for i in T.Parallel(block_M):
             if window_size is not None:
-                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
-                                               scores_max[i])
+                scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
             scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
         for i, j in T.Parallel(block_M, block_N):
@@ -150,25 +150,25 @@ def flashattn(
             logsum = T.alloc_fragment([block_M], accum_dtype)
             sinks = T.alloc_fragment([block_M], dtype)

-            T.annotate_layout({
-                Q_shared: make_swizzled_layout(Q_shared),
-                K_shared: make_swizzled_layout(K_shared),
-                V_shared: make_swizzled_layout(V_shared),
-                O_shared: make_swizzled_layout(O_shared),
-            })
+            T.annotate_layout(
+                {
+                    Q_shared: make_swizzled_layout(Q_shared),
+                    K_shared: make_swizzled_layout(K_shared),
+                    V_shared: make_swizzled_layout(V_shared),
+                    O_shared: make_swizzled_layout(O_shared),
+                }
+            )
             T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             for i in T.Parallel(block_M):
                 sinks[i] = Sinks[by]

             end = T.min(
-                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
-            start = T.max(0, (bx * block_M + past_len - window_size) //
-                          block_N) if window_size is not None else 0
+                T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)
+            )
+            start = (
+                T.max(0, (bx * block_M + past_len - window_size) // block_N)
+                if window_size is not None
+                else 0
+            )
             for k in T.Pipelined(
                 start,
@@ -176,34 +176,33 @@ def flashattn(
                 num_stages=num_stages,
                 order=[-1, 0, 3, 1, -1, 2],
                 stage=[-1, 0, 0, 1, -1, 1],
-                group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
+                group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
+            ):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
-                        logsum)
+                Softmax(
+                    acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum
+                )
                 Rescale(acc_o, scores_scale)
                 MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
             for i in T.Parallel(block_M):
-                logsum[i] += T.exp2(sinks[i] * 1.44269504 -
-                                    scores_max[i] * scale)  # The only change for attention sink
+                logsum[i] += T.exp2(
+                    sinks[i] * 1.44269504 - scores_max[i] * scale
+                )  # The only change for attention sink
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] /= logsum[i]
             T.copy(acc_o, O_shared)
             T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main


 # Following functions are adapted and optimized from
 # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
-def ref_program(query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
-                dtype: torch.dtype = torch.float16) -> torch.Tensor:
+def ref_program(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    sinks: torch.Tensor,
+    sliding_window: Optional[int] = None,
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
     query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
     key = key.transpose(1, 2).contiguous()
     value = value.transpose(1, 2).contiguous()
@@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor,
     output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
-    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
-                            head_dim).to(dtype)
+    output = output.reshape(
+        batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim
+    ).to(dtype)
     return output.transpose(1, 2).contiguous()


-def gen_inputs(B, H, Sq, Skv, D,
-               dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
-    key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
-    value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
-    sinks = torch.randn([H], dtype=dtype, device='cuda')
+def gen_inputs(
+    B,
+    H,
+    Sq,
+    Skv,
+    D,
+    dtype=torch.float16,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
+    key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
+    value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
+    sinks = torch.randn([H], dtype=dtype, device="cuda")
     return query, key, value, sinks


-def main(batch: int = 1,
-         heads: int = 32,
-         seq_q: int = 256,
-         seq_kv: int = 256,
-         dim: int = 128,
-         window_size: Optional[int] = None,
-         dtype: str = "float16",
-         tune: bool = False):
+def main(
+    batch: int = 1,
+    heads: int = 32,
+    seq_q: int = 256,
+    seq_kv: int = 256,
+    dim: int = 128,
+    window_size: Optional[int] = None,
+    dtype: str = "float16",
+    tune: bool = False,
+):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:
-        print('Using sliding window attention.')
+        print("Using sliding window attention.")
         assert window_size <= seq_q
         flops_per_matmul = 2.0 * batch * heads * min(
             window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
     else:
-        print('Using full attention.')
+        print("Using full attention.")
         flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
     total_flops = 2 * flops_per_matmul
@@ -299,15 +292,14 @@ def main(batch: int = 1,
         block_N=block_N,
         num_stages=num_stages,
         threads=threads,
-        dtype=dtype)
+        dtype=dtype,
+    )
     Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
     torch.testing.assert_close(
         kernel(Q, K, V, sinks),
         ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
         rtol=1e-2,
-        atol=1e-2)
+        atol=1e-2,
+    )
     print("All checks passed.✅")
     latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
@@ -317,19 +309,13 @@ def main(batch: int = 1,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument(
-        '--window_size',
-        type=int,
-        default=None,
-        help='window size (default: None, which means full attention)')
-    parser.add_argument('--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
-    parser.add_argument('--tune', action='store_true', help='tune')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
+    parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument(
+        "--window_size", type=int, default=None, help="window size (default: None, which means full attention)"
+    )
+    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--tune", action="store_true", help="tune")
     args = parser.parse_args()
-    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
-         args.tune)
+    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
examples/bitnet-1.58b/benchmark_generate.py

@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO")
 def generate_text_batch(model, tokenizer, prompts, max_length=100):
     # Encode the input prompts as a batch
-    input_ids = tokenizer(
-        prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
+    input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)

     # Generate cos and sin values (commented out as not used in generation)
     seq_length = input_ids.size(1)
@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
     end_time = time.time()

     # Decode the output ids to text
-    generated_texts = [
-        tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids
-    ]
+    generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]

     generation_time = end_time - start_time
     num_tokens = sum(len(output_id) for output_id in output_ids)
@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
 def profile(model, input_data):
     import numpy as np

     model = model.cuda()
     model.eval()
@@ -74,25 +71,29 @@ def profile(model, input_data):
     return np.mean(times)


-model_path = '1bitLLM/bitnet_b1_58-3B'
+model_path = "1bitLLM/bitnet_b1_58-3B"


 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--bs', default=16, type=int)
-    parser.add_argument('--in_seq_len', default=32, type=int)
-    parser.add_argument('--out_seq_len', default=128, type=int)
-    parser.add_argument('--bitblas', action='store_true')
+    parser.add_argument("--bs", default=16, type=int)
+    parser.add_argument("--in_seq_len", default=32, type=int)
+    parser.add_argument("--out_seq_len", default=128, type=int)
+    parser.add_argument("--bitblas", action="store_true")
     args = parser.parse_args()
     bs = args.bs
     in_seq_len = args.in_seq_len
     out_seq_len = args.out_seq_len
     is_bitblas = args.bitblas
-    model = BitnetForCausalLM.from_pretrained(
-        model_path,
-        use_flash_attention_2=True,
-        torch_dtype=torch.float16,
-    ).cuda().half()
+    model = (
+        BitnetForCausalLM.from_pretrained(
+            model_path,
+            use_flash_attention_2=True,
+            torch_dtype=torch.float16,
+        )
+        .cuda()
+        .half()
+    )
     if is_bitblas:
         with torch.no_grad():
             model.quantize()
@@ -109,5 +110,5 @@ def main():
     print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
examples/bitnet-1.58b/benchmark_inference_latency.py

@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
 torch.set_grad_enabled(False)

 parser = argparse.ArgumentParser()
-parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
+parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)


 def profile(model, input_data):
     import time
+
     import numpy as np

     model = model.cuda()
     model.eval()
@@ -35,8 +36,8 @@ def profile(model, input_data):
 def main():
     model = BitnetForCausalLM.from_pretrained(
-        '1bitLLM/bitnet_b1_58-3B',
-        device_map='auto',
+        "1bitLLM/bitnet_b1_58-3B",
+        device_map="auto",
         low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
@@ -52,5 +53,5 @@ def main():
     print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
examples/bitnet-1.58b/configuration_bitnet.py

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" LLaMA model configuration"""
+"""LLaMA model configuration"""

 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig):
             return

         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
             raise ValueError(
-                f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}")
+                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
             raise ValueError(
-                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
-        if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
-                                                         float) or rope_scaling_factor <= 1.0:
-            raise ValueError(
-                f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
examples/bitnet-1.58b/eval_correctness.py

@@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100):
 def profile(model, input_data):
     import numpy as np

     model = model.cuda()
     model.eval()
@@ -69,18 +69,22 @@ def profile(model, input_data):
     return np.mean(times)


-model_path = '1bitLLM/bitnet_b1_58-3B'
+model_path = "1bitLLM/bitnet_b1_58-3B"


 def main():
-    model = BitnetForCausalLM.from_pretrained(
-        model_path,
-        use_flash_attention_2=False,
-        torch_dtype=torch.float16,
-    ).cuda().half()
+    model = (
+        BitnetForCausalLM.from_pretrained(
+            model_path,
+            use_flash_attention_2=False,
+            torch_dtype=torch.float16,
+        )
+        .cuda()
+        .half()
+    )
     tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
-    input_id = tokenizer("Hello")['input_ids']
+    input_id = tokenizer("Hello")["input_ids"]
     input_id = torch.tensor(input_id).unsqueeze(0).cuda()

     print("original model generated text:")
@@ -91,5 +95,5 @@ def main():
     print(generate_text(model, tokenizer, "Hello", max_length=100))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
examples/bitnet-1.58b/eval_gpu_memory.py

@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
 torch.set_grad_enabled(False)

 parser = argparse.ArgumentParser()
-parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
+parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)


 def profile(model, input_data):
     import time
+
     import numpy as np

     model = model.cuda()
     model.eval()
@@ -35,17 +36,17 @@ def profile(model, input_data):
 def main():
     model = BitnetForCausalLM.from_pretrained(
-        '1bitLLM/bitnet_b1_58-3B',
-        device_map='auto',
+        "1bitLLM/bitnet_b1_58-3B",
+        device_map="auto",
         low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
     ).half()
     print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB")
     with torch.no_grad():
         model._post_process_weights()
     print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
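eval_gpu_memory.py reports torch.cuda.memory_allocated() scaled by 1024**3 — bytes to GiB — once before and once after _post_process_weights() swaps in the BitBLAS weights. A minimal runnable sketch of the same measurement pattern (the tensor below is a stand-in, not the BitNet model):

import torch

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated() / 1024**3  # GiB currently allocated
    x = torch.empty(1024, 1024, 256, dtype=torch.float16, device="cuda")  # ~0.5 GiB stand-in
    after = torch.cuda.memory_allocated() / 1024**3
    print(f"gpu memory: {after - before} GB")  # mirrors the script's print format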
examples/bitnet-1.58b/eval_ppl.py

@@ -15,9 +15,9 @@ from tqdm import tqdm
 torch.set_grad_enabled(False)

 parser = argparse.ArgumentParser()
-parser.add_argument('--seed', default=0, type=int)
-parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
-parser.add_argument('--seqlen', default=2048, type=int)
+parser.add_argument("--seed", default=0, type=int)
+parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
+parser.add_argument("--seqlen", default=2048, type=int)


 def calulate_loss(model, input, loss_fct):
@@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct):
 def main(args):
-    datasets = ['c4', 'wikitext2']
-    model = BitnetForCausalLM.from_pretrained(
-        args.hf_path,
-        use_flash_attention_2=True,
-        torch_dtype=torch.float16,
-    ).cuda().half()
+    datasets = ["c4", "wikitext2"]
+    model = (
+        BitnetForCausalLM.from_pretrained(
+            args.hf_path,
+            use_flash_attention_2=True,
+            torch_dtype=torch.float16,
+        )
+        .cuda()
+        .half()
+    )
     with torch.no_grad():
         model._post_process_weights()
     tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
@@ -48,9 +52,9 @@ def main(args):
     for ii in progress:
         input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
         loss = calulate_loss(model, input, loss_fct)
-        count += (input.size(-1) - 1)
+        count += input.size(-1) - 1
         acc_loss += loss.item()
         progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")

     avg_loss = acc_loss / count / math.log(2)
     ppl.append(2**avg_loss)
@@ -60,7 +64,7 @@ def main(args):
     print("Avg PPL:", sum(ppl) / len(ppl))


-if __name__ == '__main__':
+if __name__ == "__main__":
     torch.set_grad_enabled(False)
     args = parser.parse_args()
     random.seed(args.seed)
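In the loop above, eval_ppl.py accumulates the natural-log cross-entropy, converts the per-token mean to bits by dividing by math.log(2), and reports perplexity as 2 raised to that value; since 2^(L / ln 2) = e^L, this equals the usual exp(mean NLL). A tiny check of that identity with made-up losses (the numbers are hypothetical, not from the script):

import math

# Hypothetical per-token natural-log losses accumulated over `count` tokens.
losses_nats = [2.1, 1.7, 2.4, 1.9]
acc_loss, count = sum(losses_nats), len(losses_nats)

avg_loss = acc_loss / count / math.log(2)  # mean loss in bits, as in eval_ppl.py
ppl = 2**avg_loss                          # as in ppl.append(2**avg_loss)

# The same quantity computed directly in nats: PPL = exp(mean NLL).
assert abs(ppl - math.exp(acc_loss / count)) < 1e-9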
examples/bitnet-1.58b/eval_utils.py

@@ -15,21 +15,17 @@ def set_seed(seed):
 def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
     if dataset_name == "wikitext2":
-        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
-        testdata = "".join(testdata['text']).split('\n')
+        testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+        testdata = "".join(testdata["text"]).split("\n")
     elif dataset_name == "c4":
         testdata = load_dataset(
-            'allenai/c4',
-            data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
-            split='validation')['text']
+            "allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation"
+        )["text"]
     else:
         raise NotImplementedError
     testdata = [item for item in testdata if item != ""]
-    tokenized_text = [
-        tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id]
-        for item in testdata
-    ]
+    tokenized_text = [
+        tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata
+    ]
     data, doc = [], [tokenizer.bos_token_id]
     for sen in tokenized_text:
@@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
 class LMEvalAdaptor(BaseLM):
-
     def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
         super().__init__()
@@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM):
         return out

     def _model_generate(self, context, max_length, eos_token_id):
-        return self.model.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
+        return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)
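get_test_dataset flattens every tokenized document into one stream — the stream starts from a bos token and each document gets its eos appended — before cutting fixed seqlen windows for evaluation. A compact sketch of one plausible packing consistent with the visible code (token ids and window size here are hypothetical):

BOS, EOS, SEQLEN = 1, 2, 4  # hypothetical ids and a toy window size

docs = [[10, 11], [12, 13, 14], [15]]
stream = [BOS]
for doc in docs:
    stream += doc + [EOS]  # each document is terminated by eos, as in eval_utils.py

# Fixed-length evaluation windows; a trailing remainder shorter than SEQLEN is dropped.
windows = [stream[i : i + SEQLEN] for i in range(0, len(stream) - SEQLEN + 1, SEQLEN)]
assert windows == [[1, 10, 11, 2], [12, 13, 14, 2]]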
examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py

@@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode(
             for v in T.vectorized(micro_size_k_compressed):
-                B_quant_local[v] = B[bx * n_partition + ni,
-                                     ko * (reduce_thread * micro_size_k_compressed) +
-                                     kr * micro_size_k_compressed + v]
+                B_quant_local[v] = B[
+                    bx * n_partition + ni,
+                    ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
+                ]

             T.call_extern(
@@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode(
                     reduced_accum_res[0],
                     kr,
                     dtype="handle",
-                ))
+                )
+            )
             if kr == 0:
                 C[by, bx * n_partition + ni] = reduced_accum_res[0]
@@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
     return new_qweight.view(np.int8)


-def assert_bitnet_158_int8xint2_decode_correctness(M,
-                                                   N,
-                                                   K,
-                                                   in_dtype,
-                                                   out_dtype,
-                                                   accum_dtype,
-                                                   fast_decoding=True):
+def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
     program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
     print(program)
     kernel = tilelang.compile(program)
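The decode and prefill kernel benchmarks consume ternary weights packed four 2-bit fields per byte (micro_size_k_compressed, num_elems_per_byte) and expand them in-kernel through the decode_i2s_to_i8s extern. A NumPy sketch of the plain low-bits-first unpacking such a kernel performs — the layout is an assumption for illustration; the real extern also handles the interleaved ordering produced by interleave_weight:

import numpy as np

def unpack_int2(packed: np.ndarray) -> np.ndarray:
    # Unpack each uint8 byte into four 2-bit fields, lowest bits first.
    # Assumed layout for illustration; decode_i2s_to_i8s also de-interleaves.
    shifts = np.arange(4, dtype=np.uint8) * 2    # bit offsets 0, 2, 4, 6
    vals = (packed[..., None] >> shifts) & 0b11  # four fields per byte
    return vals.reshape(*packed.shape[:-1], -1).astype(np.int8)

packed = np.array([0b11100100], dtype=np.uint8)  # one byte holding fields 0, 1, 2, 3
assert unpack_int2(packed).tolist() == [0, 1, 2, 3]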
examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py

@@ -8,11 +8,13 @@ import tilelang.language as T
 from tilelang import tvm as tvm
 from tvm import DataType
 from tilelang.intrinsics.mma_layout import (
-    make_mma_swizzle_layout as make_swizzle_layout,)
+    make_mma_swizzle_layout as make_swizzle_layout,
+)
 import numpy as np

 from tilelang.intrinsics.mma_macro_generator import (
-    INT4TensorCoreIntrinEmitter,)
+    INT4TensorCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func

 torch.manual_seed(42)
@@ -208,11 +210,9 @@ def bitnet_158_int8xint2_prefill(
         threads=threads,
         prelude=decode_i2s_to_i8s,
     ) as (bx, by):
         A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
         B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
-        B_dequantize_shared = T.alloc_shared(
-            B_dequantize_shared_shape, in_dtype, scope=shared_scope)
+        B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope)
         C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
         A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
         B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
@@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill(
         thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

-        T.annotate_layout({
-            A_shared: make_swizzle_layout(A_shared),
-            B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
-        })
+        T.annotate_layout(
+            {
+                A_shared: make_swizzle_layout(A_shared),
+                B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
+            }
+        )

         # Improve L2 Cache
         T.use_swizzle(panel_size=10)
@@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill(
         T.clear(C_frag)
         for ko in T.Pipelined((K // block_K), num_stages=stage):
-
             # Load A into shared memory
             for i, k in T.Parallel(block_M, block_K):
                 A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill(
             for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
                 B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]

-            for i in T.serial(block_N * block_K // num_elems_per_byte //
-                              (threads * local_size_compressed)):
+            for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
                 for v in T.vectorized(0, local_size_compressed):
-                    index = (
-                        i * threads * local_size_compressed + thread_bindings * local_size_compressed + v)
+                    index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v
                     vi, vj = T.index_to_coordinates(index, B_shared_shape)
                     B_local[v] = B_shared[vi, vj]
@@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill(
                 )
                 for v in T.vectorized(0, local_size):
-                    index = (i * threads * local_size + thread_bindings * local_size + v)
+                    index = i * threads * local_size + thread_bindings * local_size + v
                     vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
                     B_dequantize_shared[vi, vj] = B_dequantize_local[v]

             for ki in T.serial(0, (block_K // micro_size_k)):
                 # Load A into fragment
                 mma_emitter.ldmatrix_a(
                     A_frag,
@@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
     return new_qweight.view(np.int8)


-def assert_bitnet_158_int8xint2_prefill_correctness(M,
-                                                    N,
-                                                    K,
-                                                    in_dtype,
-                                                    out_dtype,
-                                                    accum_dtype,
-                                                    fast_decoding=True):
+def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
     program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
     print(program)
     kernel = tilelang.compile(program)
examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py
View file @
29051439
...
@@ -6,7 +6,8 @@ from tvm import tl as TL
...
@@ -6,7 +6,8 @@ from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.mma_macro_generator import (
-   TensorCoreIntrinEmitter,)
+   TensorCoreIntrinEmitter,
+)
from bitblas.base import simplify_prim_func

torch.manual_seed(0)
...
@@ -106,7 +107,6 @@ def tl_matmul(
    C: T.Buffer((M, N), out_dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
        B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
        C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...
@@ -116,10 +116,12 @@ def tl_matmul(
        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

-       T.annotate_layout({
+       T.annotate_layout(
+           {
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
-       })
+           }
+       )
        # Improve L2 Cache
        T.use_swizzle(panel_size=10)
...
@@ -127,7 +129,6 @@ def tl_matmul(
        T.clear(C_local)
        for ko in T.Pipelined((K // block_K), num_stages=stage):
            # Load A into shared memory
            for i, k in T.Parallel(block_M, block_K):
                A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...
@@ -137,7 +138,6 @@ def tl_matmul(
                B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
            for ki in T.serial(0, (block_K // micro_size_k)):
                # Load A into fragment
                mma_emitter.ldmatrix_a(
                    A_local,
...
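The `TensorCoreIntrinEmitter,)` and `annotate_layout({` rewrites in this file come from ruff's magic trailing comma: a bracketed sequence that already ends with a comma is always exploded one element per line, with the closing bracket dedented onto its own line. A sketch with a hypothetical module name:

    # before: yapf tolerated the closing paren on the last element
    from some_module import (
        SomeEmitter,)

    # after: ruff format honors the trailing comma
    from some_module import (
        SomeEmitter,
    )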
examples/bitnet-1.58b/load_from_quantized.py  View file @ 29051439
...
@@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100):
def main():
    # load quantized model
-   qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
+   qmodel = (
+       BitnetForCausalLM.from_quantized(
+           saved_model_path,
+       )
+       .cuda()
+       .half()
+   )
    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    # print("original model generated text:")
    # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
...
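The `qmodel` rewrite shows ruff format's layout for fluent call chains that no longer fit on one line: the whole chain is wrapped in parentheses and each `.method()` link starts its own line (the trailing comma after the argument forces the call itself to stay exploded). A rough sketch under ruff's defaults, where `Model` and `path` are hypothetical:

    model = (
        Model.from_quantized(
            path,
        )
        .cuda()
        .half()
    )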
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py  View file @ 29051439
...
@@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()

model_name_or_path = args.model_name_or_path
-saved_model_path = os.path.join(
-    dirpath, "models",
-    f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
+saved_model_path = (
+    os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")
+    if args.saved_model_path is None
+    else args.saved_model_path
+)

def generate_text(model, tokenizer, prompt, max_length=100):
...
@@ -67,7 +67,10 @@ def main():
            model_name_or_path,
            use_flash_attention_2=False,
            torch_dtype=torch.float16,
-       ).cuda().half())
+       )
+       .cuda()
+       .half()
+   )
    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    # print("original model generated text:")
...
@@ -112,10 +115,16 @@ def main():
            file_path = cached_file(model_name_or_path, file)
            os.system(f"cp {file_path} {saved_model_path}")
    # load quantized model
-   qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
+   qmodel = (
+       BitnetForCausalLM.from_quantized(
+           saved_model_path,
+       )
+       .cuda()
+       .half()
+   )
    print("quantized model generated text:")
    print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))

-if __name__ == '__main__':
+if __name__ == "__main__":
    main()
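Besides re-wrapping, ruff format normalizes quote style, which is why `'__main__'` above became `"__main__"`: single-quoted strings are rewritten with double quotes unless that would require extra escaping. A one-line sketch with a hypothetical name:

    name = 'bitnet'      # before
    name = "bitnet"      # after ruff format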
examples/bitnet-1.58b/modeling_bitnet.py  View file @ 29051439
This diff is collapsed.
examples/bitnet-1.58b/tokenization_bitnet.py  View file @ 29051439
...
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
+
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
...
@@ -37,12 +38,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
-       "hf-internal-testing/llama-tokenizer":
-           "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
+       "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
    },
    "tokenizer_file": {
-       "hf-internal-testing/llama-tokenizer":
-           "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
+       "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
...
@@ -159,14 +158,10 @@ class BitnetTokenizer(PreTrainedTokenizer):
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-       bos_token = AddedToken(
-           bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
-       eos_token = AddedToken(
-           eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
-       unk_token = AddedToken(
-           unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
-       pad_token = AddedToken(
-           pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+       bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+       eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+       unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+       pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
        if legacy is None:
            logger.warning_once(
...
@@ -174,7 +169,8 @@ class BitnetTokenizer(PreTrainedTokenizer):
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
" If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
)
" https://github.com/huggingface/transformers/pull/24565"
)
legacy
=
True
legacy
=
True
self
.
legacy
=
legacy
self
.
legacy
=
legacy
...
@@ -214,8 +210,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
-           model_pb2 = import_protobuf(
-               f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
+           model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
...
@@ -261,8 +256,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
        tokens = super().tokenize(text, **kwargs)
-       if len(tokens
-             ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+       if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens
...
@@ -284,7 +278,7 @@ class BitnetTokenizer(PreTrainedTokenizer):
        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
-       return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens
+       return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
...
@@ -332,12 +326,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
                self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
...
@@ -357,10 +348,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
        return output

-   def get_special_tokens_mask(self,
-                               token_ids_0: List[int],
-                               token_ids_1: Optional[List[int]] = None,
-                               already_has_special_tokens: bool = False) -> List[int]:
+   def get_special_tokens_mask(
+       self,
+       token_ids_0: List[int],
+       token_ids_1: Optional[List[int]] = None,
+       already_has_special_tokens: bool = False,
+   ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
...
@@ -377,20 +367,16 @@ class BitnetTokenizer(PreTrainedTokenizer):
        `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []
        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
-       return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + \
-           ([0] * len(token_ids_1)) + eos_token_id
+       return (
+           bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id +
+           ([0] * len(token_ids_1)) + eos_token_id
+       )

    def create_token_type_ids_from_sequences(self,
            token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:
...
@@ -473,9 +459,9 @@ class BitnetTokenizer(PreTrainedTokenizer):
"{% elif message['role'] == 'assistant' %}"
"{% elif message['role'] == 'assistant' %}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{% endif %}"
"{% endif %}"
"{% endfor %}"
)
"{% endfor %}"
template
=
template
.
replace
(
"USE_DEFAULT_PROMPT"
,
)
"true"
if
self
.
use_default_system_prompt
else
"false"
)
template
=
template
.
replace
(
"USE_DEFAULT_PROMPT"
,
"true"
if
self
.
use_default_system_prompt
else
"false"
)
default_message
=
DEFAULT_SYSTEM_PROMPT
.
replace
(
"
\n
"
,
"
\\
n"
).
replace
(
"'"
,
"
\\
'"
)
default_message
=
DEFAULT_SYSTEM_PROMPT
.
replace
(
"
\n
"
,
"
\\
n"
).
replace
(
"'"
,
"
\\
'"
)
template
=
template
.
replace
(
"DEFAULT_SYSTEM_MESSAGE"
,
default_message
)
template
=
template
.
replace
(
"DEFAULT_SYSTEM_MESSAGE"
,
default_message
)
...
...
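One subtle rule in this file: the `@@ -284,7 +278,7 @@` hunk most likely reflects the black-compatible slice rule, which puts spaces around the colon when a slice bound is a non-trivial expression, so `tokens[self.unk_token_length:]` becomes `tokens[self.unk_token_length :]`. A sketch of the rule with hypothetical names:

    xs[1:]          # simple bound: no spaces
    xs[a + b :]     # complex bound: spaces around the colon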
examples/bitnet-1.58b/utils_quant.py  View file @ 29051439
...
@@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1):
def activation_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
-   Qn = -(2**(num_bits - 1))
-   Qp = 2**(num_bits - 1) - 1
+   Qn = -(2 ** (num_bits - 1))
+   Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)

class BitLinearBitBLAS(nn.Module):
    def __init__(
        self,
        in_features: int,
...
@@ -68,7 +67,7 @@ class BitLinearBitBLAS(nn.Module):
        self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)
        self.format = "bitnet"
-       self.Qp = 2**(self.input_bits - 1) - 1
+       self.Qp = 2 ** (self.input_bits - 1) - 1

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        if global_operator_cache.size() == 0:
...
@@ -99,8 +98,7 @@ class BitLinearBitBLAS(nn.Module):
    @classmethod
    def from_bit_linear(cls, bitlinear, weight_group=1):
-       bitblas_linear = cls(
-           bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
+       bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
        sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
        bitblas_linear.register_buffer("qweight", qweight)
        bitblas_linear.register_buffer("sw", sw)
...
@@ -158,8 +156,8 @@ class BitLinearBitBLAS(nn.Module):
    @torch.compile
    def activation_quant(self, x, num_bits=8):
        x = x.float()
-       Qn = -(2**(num_bits - 1))
-       Qp = 2**(num_bits - 1) - 1
+       Qn = -(2 ** (num_bits - 1))
+       Qp = 2 ** (num_bits - 1) - 1
        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        result = (x * s).round().clamp(Qn, Qp)
        return result.type(torch.int8), s
...
@@ -173,9 +171,8 @@ class BitLinearBitBLAS(nn.Module):
    # for the correctness evaluation.
    def native_forward(self, input):
-       quant_input = (input + (activation_quant(input, self.input_bits) - input).detach())
-       quant_weight = (
-           self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach())
+       quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
+       quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
...
@@ -214,7 +211,6 @@ class BitLinearBitBLAS(nn.Module):
# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):
    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        """
...
@@ -224,10 +220,8 @@ class BitLinear(nn.Linear):
        self.input_bits = input_bits

    def forward(self, input):
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
-       quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) -
-                                     self.weight).detach()
+       quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
...
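utils_quant.py collects two more rules seen above: redundant parentheses around a simple right-hand side are dropped, and, in the black-compatible style, the power operator is spaced when its operands are not simple atoms. A sketch where `f` and `bits` are hypothetical:

    y = (x + f(x).detach())      # before: redundant outer parens
    y = x + f(x).detach()        # after

    q = 2**(bits - 1) - 1        # before
    q = 2 ** (bits - 1) - 1      # after: parenthesized operand is not a simple atom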
examples/bitnet-1.58b/vllm_workspace/conftest.py  View file @ 29051439
...
@@ -20,7 +20,7 @@ from transformers import (
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
-from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel)
+from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
...
@@ -56,12 +56,13 @@ else:
class _ImageAssets(_ImageAssetsBase):
    def __init__(self) -> None:
-       super().__init__([
+       super().__init__(
+           [
                ImageAsset("stop_sign"),
                ImageAsset("cherry_blossom"),
-       ])
+           ]
+       )

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
...
@@ -136,7 +137,6 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
    def wrap_device(self, input: _T) -> _T:
        if not is_cpu():
            return input.to("cuda")
...
@@ -166,7 +166,8 @@ class HfRunner:
                SentenceTransformer(
                    model_name,
                    device="cpu",
-               ).to(dtype=torch_dtype))
+               ).to(dtype=torch_dtype)
+           )
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
...
@@ -184,7 +185,8 @@ class HfRunner:
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **model_kwargs,
-           ))
+           )
+       )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
...
@@ -204,8 +206,7 @@ class HfRunner:
            )
        except Exception:
            logger.warning(
-               "Unable to auto-load processor from HuggingFace for "
-               "model %s. Using tokenizer instead.",
+               "Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.",
                model_name,
            )
            self.processor = self.tokenizer
...
@@ -362,7 +363,7 @@ class HfRunner:
                last_hidden_states,
                self.model.get_output_embeddings().weight.t(),
            )
-           if (getattr(self.model.get_output_embeddings(), "bias", None) is not None):
+           if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)
...
@@ -389,8 +390,7 @@ class HfRunner:
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-       return [(output_ids, output_str, output_logprobs)
-               for output_ids, output_str, output_logprobs in outputs]
+       return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)
...
@@ -409,7 +409,6 @@ def hf_runner():
class VllmRunner:
    def __init__(
        self,
        model_name: str,
...
@@ -514,12 +513,10 @@ class VllmRunner:
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-       greedy_logprobs_params = SamplingParams(
-           temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
+       greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images)
-       return [(output_ids, output_str, output_logprobs)
-               for output_ids, output_str, output_logprobs in outputs]
+       return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
...
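The conftest.py hunks are largely pure line-joining: comprehensions and call argument lists that yapf had wrapped are put back on one line when they fit within the configured line length. A sketch where `pairs` is a hypothetical iterable:

    # before
    return [(a, b)
            for a, b in pairs]

    # after
    return [(a, b) for a, b in pairs]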
examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py  View file @ 29051439
...
@@ -39,8 +39,7 @@ with VllmRunner(
    # set enforce_eager = True to disable cuda graph
    enforce_eager=False,
) as bitnet_model:
-   bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"],
-                                                  max_tokens=1024)
+   bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024)
    print("bitnet inference:")
    print(bitbnet_outputs[0][0])
    print(bitbnet_outputs[0][1])
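To reproduce the formatting in this commit locally, the usual invocation is `ruff format .` to rewrite files in place, or `ruff format --check .` to report files that would change without touching them; this assumes ruff is installed in the project's dev environment and picks up the repository's own configuration.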
examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py  View file @ 29051439