Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
5a4263f4
Commit
5a4263f4
authored
Feb 24, 2024
by
Ruff
Committed by
Aarni Koskela
Mar 13, 2024
Browse files
Reformat with ruff-format
parent
02e30ca6
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
914 additions
and
902 deletions
+914
-902
bitsandbytes/triton/int8_matmul_mixed_dequantize.py
bitsandbytes/triton/int8_matmul_mixed_dequantize.py
+94
-54
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
+94
-53
bitsandbytes/triton/quantize_columnwise_and_transpose.py
bitsandbytes/triton/quantize_columnwise_and_transpose.py
+25
-23
bitsandbytes/triton/quantize_global.py
bitsandbytes/triton/quantize_global.py
+50
-31
bitsandbytes/triton/quantize_rowwise.py
bitsandbytes/triton/quantize_rowwise.py
+19
-18
bitsandbytes/utils.py
bitsandbytes/utils.py
+13
-15
check_bnb_install.py
check_bnb_install.py
+5
-5
examples/int8_inference_huggingface.py
examples/int8_inference_huggingface.py
+4
-9
install_cuda.py
install_cuda.py
+12
-4
scripts/stale.py
scripts/stale.py
+2
-1
tests/test_autograd.py
tests/test_autograd.py
+61
-79
tests/test_cuda_setup_evaluator.py
tests/test_cuda_setup_evaluator.py
+1
-4
tests/test_functional.py
tests/test_functional.py
+342
-435
tests/test_generation.py
tests/test_generation.py
+39
-36
tests/test_linear4bit.py
tests/test_linear4bit.py
+5
-7
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+15
-3
tests/test_modules.py
tests/test_modules.py
+73
-75
tests/test_optim.py
tests/test_optim.py
+49
-42
tests/test_triton.py
tests/test_triton.py
+11
-8
No files found.
bitsandbytes/triton/int8_matmul_mixed_dequantize.py
View file @
5a4263f4
...
@@ -3,14 +3,14 @@ import torch
...
@@ -3,14 +3,14 @@ import torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.triton.triton_utils
import
is_triton_available
if
not
is_triton_available
():
if
not
is_triton_available
():
def
int8_matmul_mixed_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
return
None
else
:
def
int8_matmul_mixed_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
return
None
else
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
triton.ops.matmul_perf_model
import
early_config_prune
,
estimate_matmul_time
from
triton.ops.matmul_perf_model
import
early_config_prune
,
estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
# It's purpose is fused matmul then dequantize
...
@@ -27,58 +27,83 @@ else:
...
@@ -27,58 +27,83 @@ else:
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
block_n
in
[
32
,
64
,
128
,
256
]:
num_warps
=
2
if
block_n
<=
64
else
4
num_warps
=
2
if
block_n
<=
64
else
4
configs
.
append
(
configs
.
append
(
triton
.
Config
({
'BLOCK_M'
:
block_m
,
'BLOCK_N'
:
block_n
,
'BLOCK_K'
:
block_k
,
'SPLIT_K'
:
1
},
triton
.
Config
(
num_stages
=
num_stages
,
num_warps
=
num_warps
))
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
1
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
),
)
# split_k
# split_k
for
split_k
in
[
2
,
4
,
8
,
16
]:
for
split_k
in
[
2
,
4
,
8
,
16
]:
configs
.
append
(
triton
.
Config
({
'BLOCK_M'
:
block_m
,
'BLOCK_N'
:
block_n
,
'BLOCK_K'
:
block_k
,
'SPLIT_K'
:
split_k
},
configs
.
append
(
num_stages
=
num_stages
,
num_warps
=
num_warps
,
pre_hook
=
init_to_zero
(
'C'
)))
triton
.
Config
(
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
split_k
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
pre_hook
=
init_to_zero
(
"C"
),
),
)
return
configs
return
configs
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
# basic configs for compute-bound matmuls
# basic configs for compute-bound matmuls
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
# good for int8
# good for int8
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
*
get_configs_io_bound
(),
*
get_configs_io_bound
(),
],
],
key
=
[
'M'
,
'N'
,
'K'
],
key
=
[
"M"
,
"N"
,
"K"
],
prune_configs_by
=
{
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
},
'early_config_prune'
:
early_config_prune
,
)
'perf_model'
:
estimate_matmul_time
,
@
triton
.
heuristics
(
'top_k'
:
10
{
"EVEN_K"
:
lambda
args
:
args
[
"K"
]
%
(
args
[
"BLOCK_K"
]
*
args
[
"SPLIT_K"
])
==
0
,
},
},
)
)
@
triton
.
heuristics
({
'EVEN_K'
:
lambda
args
:
args
[
'K'
]
%
(
args
[
'BLOCK_K'
]
*
args
[
'SPLIT_K'
])
==
0
,
})
@
triton
.
jit
@
triton
.
jit
def
_int8_matmul_mixed_dequantize
(
A
,
B
,
C
,
bias
,
state_x_ptr
,
state_w_ptr
,
M
,
N
,
K
,
divfactor
:
tl
.
constexpr
,
has_bias
:
tl
.
constexpr
,
def
_int8_matmul_mixed_dequantize
(
stride_am
,
stride_ak
,
A
,
stride_bk
,
stride_bn
,
B
,
stride_cm
,
stride_cn
,
C
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
bias
,
GROUP_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
state_x_ptr
,
ACC_TYPE
:
tl
.
constexpr
state_w_ptr
,
):
M
,
N
,
K
,
divfactor
:
tl
.
constexpr
,
has_bias
:
tl
.
constexpr
,
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
ACC_TYPE
:
tl
.
constexpr
,
):
# matrix multiplication
# matrix multiplication
pid
=
tl
.
program_id
(
0
)
pid
=
tl
.
program_id
(
0
)
pid_z
=
tl
.
program_id
(
1
)
pid_z
=
tl
.
program_id
(
1
)
...
@@ -115,13 +140,13 @@ else:
...
@@ -115,13 +140,13 @@ else:
b
=
tl
.
load
(
B
)
b
=
tl
.
load
(
B
)
else
:
else
:
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k_remaining
,
other
=
0.
)
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k_remaining
,
other
=
0.
0
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k_remaining
,
other
=
0.
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k_remaining
,
other
=
0.
0
)
acc
+=
tl
.
dot
(
a
,
b
)
acc
+=
tl
.
dot
(
a
,
b
)
A
+=
BLOCK_K
*
SPLIT_K
*
stride_ak
A
+=
BLOCK_K
*
SPLIT_K
*
stride_ak
B
+=
BLOCK_K
*
SPLIT_K
*
stride_bk
B
+=
BLOCK_K
*
SPLIT_K
*
stride_bk
acc
=
(
w_factor
*
(
x_factor
*
(
acc
*
divfactor
))
)
acc
=
w_factor
*
(
x_factor
*
(
acc
*
divfactor
))
acc
=
acc
.
to
(
C
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
C
.
dtype
.
element_ty
)
# conditionally add bias
# conditionally add bias
...
@@ -137,10 +162,9 @@ else:
...
@@ -137,10 +162,9 @@ else:
else
:
else
:
tl
.
atomic_add
(
C
,
acc
,
mask
=
mask
)
tl
.
atomic_add
(
C
,
acc
,
mask
=
mask
)
def
int8_matmul_mixed_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
def
int8_matmul_mixed_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
device
=
a
.
device
device
=
a
.
device
divfactor
=
1.
/
(
127.
*
127.
)
divfactor
=
1.
0
/
(
127.
0
*
127.
0
)
has_bias
=
0
if
bias
is
None
else
1
has_bias
=
0
if
bias
is
None
else
1
# handle non-contiguous inputs if necessary
# handle non-contiguous inputs if necessary
if
a
.
stride
(
0
)
>
1
and
a
.
stride
(
1
)
>
1
:
if
a
.
stride
(
0
)
>
1
and
a
.
stride
(
1
)
>
1
:
...
@@ -154,12 +178,28 @@ else:
...
@@ -154,12 +178,28 @@ else:
# allocates output
# allocates output
c
=
torch
.
empty
((
M
,
N
),
device
=
device
,
dtype
=
torch
.
float16
)
c
=
torch
.
empty
((
M
,
N
),
device
=
device
,
dtype
=
torch
.
float16
)
# accumulator types
# accumulator types
ACC_TYPE
=
tl
.
float32
#
if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE
=
tl
.
float32
#
if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_mixed_dequantize kernel
# launch int8_matmul_mixed_dequantize kernel
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
'BLOCK_M'
])
*
triton
.
cdiv
(
N
,
META
[
'BLOCK_N'
]),
META
[
'SPLIT_K'
])
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_N"
]),
META
[
"SPLIT_K"
])
_int8_matmul_mixed_dequantize
[
grid
](
a
,
b
,
c
,
bias
,
state_x
,
state_w
,
M
,
N
,
K
,
divfactor
,
has_bias
,
_int8_matmul_mixed_dequantize
[
grid
](
a
.
stride
(
0
),
a
.
stride
(
1
),
a
,
b
.
stride
(
0
),
b
.
stride
(
1
),
b
,
c
.
stride
(
0
),
c
.
stride
(
1
),
c
,
GROUP_M
=
8
,
ACC_TYPE
=
ACC_TYPE
)
bias
,
state_x
,
state_w
,
M
,
N
,
K
,
divfactor
,
has_bias
,
a
.
stride
(
0
),
a
.
stride
(
1
),
b
.
stride
(
0
),
b
.
stride
(
1
),
c
.
stride
(
0
),
c
.
stride
(
1
),
GROUP_M
=
8
,
ACC_TYPE
=
ACC_TYPE
,
)
return
c
return
c
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
View file @
5a4263f4
...
@@ -3,7 +3,9 @@ import torch
...
@@ -3,7 +3,9 @@ import torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.triton.triton_utils
import
is_triton_available
if
not
is_triton_available
():
if
not
is_triton_available
():
def
int8_matmul_rowwise_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
return
None
def
int8_matmul_rowwise_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
return
None
else
:
else
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
...
@@ -17,7 +19,6 @@ else:
...
@@ -17,7 +19,6 @@ else:
def
init_to_zero
(
name
):
def
init_to_zero
(
name
):
return
lambda
nargs
:
nargs
[
name
].
zero_
()
return
lambda
nargs
:
nargs
[
name
].
zero_
()
def
get_configs_io_bound
():
def
get_configs_io_bound
():
configs
=
[]
configs
=
[]
for
num_stages
in
[
2
,
3
,
4
,
5
,
6
]:
for
num_stages
in
[
2
,
3
,
4
,
5
,
6
]:
...
@@ -26,58 +27,83 @@ else:
...
@@ -26,58 +27,83 @@ else:
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
block_n
in
[
32
,
64
,
128
,
256
]:
num_warps
=
2
if
block_n
<=
64
else
4
num_warps
=
2
if
block_n
<=
64
else
4
configs
.
append
(
configs
.
append
(
triton
.
Config
({
'BLOCK_M'
:
block_m
,
'BLOCK_N'
:
block_n
,
'BLOCK_K'
:
block_k
,
'SPLIT_K'
:
1
},
triton
.
Config
(
num_stages
=
num_stages
,
num_warps
=
num_warps
))
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
1
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
),
)
# split_k
# split_k
for
split_k
in
[
2
,
4
,
8
,
16
]:
for
split_k
in
[
2
,
4
,
8
,
16
]:
configs
.
append
(
triton
.
Config
({
'BLOCK_M'
:
block_m
,
'BLOCK_N'
:
block_n
,
'BLOCK_K'
:
block_k
,
'SPLIT_K'
:
split_k
},
configs
.
append
(
num_stages
=
num_stages
,
num_warps
=
num_warps
,
pre_hook
=
init_to_zero
(
'C'
)))
triton
.
Config
(
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
split_k
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
pre_hook
=
init_to_zero
(
"C"
),
),
)
return
configs
return
configs
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
# basic configs for compute-bound matmuls
# basic configs for compute-bound matmuls
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
32
,
'
SPLIT_K
'
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
32
,
"
SPLIT_K
"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
# good for int8
# good for int8
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
'
BLOCK_M
'
:
256
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
256
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
256
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
256
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
128
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
128
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
64
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
64
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
128
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
128
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
128
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"
BLOCK_M
"
:
128
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
'
BLOCK_M
'
:
64
,
'
BLOCK_N
'
:
32
,
'
BLOCK_K
'
:
64
,
'
SPLIT_K
'
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
({
"
BLOCK_M
"
:
64
,
"
BLOCK_N
"
:
32
,
"
BLOCK_K
"
:
64
,
"
SPLIT_K
"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
*
get_configs_io_bound
(),
*
get_configs_io_bound
(),
],
],
key
=
[
'M'
,
'N'
,
'K'
],
key
=
[
"M"
,
"N"
,
"K"
],
prune_configs_by
=
{
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
},
'early_config_prune'
:
early_config_prune
,
)
'perf_model'
:
estimate_matmul_time
,
@
triton
.
heuristics
(
'top_k'
:
10
{
"EVEN_K"
:
lambda
args
:
args
[
"K"
]
%
(
args
[
"BLOCK_K"
]
*
args
[
"SPLIT_K"
])
==
0
,
},
},
)
)
@
triton
.
heuristics
({
'EVEN_K'
:
lambda
args
:
args
[
'K'
]
%
(
args
[
'BLOCK_K'
]
*
args
[
'SPLIT_K'
])
==
0
,
})
@
triton
.
jit
@
triton
.
jit
def
_int8_matmul_rowwise_dequantize
(
A
,
B
,
C
,
bias
,
state_x_ptr
,
state_w_ptr
,
M
,
N
,
K
,
divfactor
,
has_bias
:
tl
.
constexpr
,
def
_int8_matmul_rowwise_dequantize
(
stride_am
,
stride_ak
,
A
,
stride_bk
,
stride_bn
,
B
,
stride_cm
,
stride_cn
,
C
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
bias
,
GROUP_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
state_x_ptr
,
ACC_TYPE
:
tl
.
constexpr
state_w_ptr
,
):
M
,
N
,
K
,
divfactor
,
has_bias
:
tl
.
constexpr
,
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
ACC_TYPE
:
tl
.
constexpr
,
):
# matrix multiplication
# matrix multiplication
pid
=
tl
.
program_id
(
0
)
pid
=
tl
.
program_id
(
0
)
pid_z
=
tl
.
program_id
(
1
)
pid_z
=
tl
.
program_id
(
1
)
...
@@ -114,13 +140,13 @@ else:
...
@@ -114,13 +140,13 @@ else:
b
=
tl
.
load
(
B
)
b
=
tl
.
load
(
B
)
else
:
else
:
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k_remaining
,
other
=
0.
)
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k_remaining
,
other
=
0.
0
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k_remaining
,
other
=
0.
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k_remaining
,
other
=
0.
0
)
acc
+=
tl
.
dot
(
a
,
b
)
acc
+=
tl
.
dot
(
a
,
b
)
A
+=
BLOCK_K
*
SPLIT_K
*
stride_ak
A
+=
BLOCK_K
*
SPLIT_K
*
stride_ak
B
+=
BLOCK_K
*
SPLIT_K
*
stride_bk
B
+=
BLOCK_K
*
SPLIT_K
*
stride_bk
acc
=
(
w_factor
*
(
x_factor
*
(
acc
*
divfactor
))
)
acc
=
w_factor
*
(
x_factor
*
(
acc
*
divfactor
))
acc
=
acc
.
to
(
C
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
C
.
dtype
.
element_ty
)
if
has_bias
:
if
has_bias
:
...
@@ -135,9 +161,8 @@ else:
...
@@ -135,9 +161,8 @@ else:
else
:
else
:
tl
.
atomic_add
(
C
,
acc
,
mask
=
mask
)
tl
.
atomic_add
(
C
,
acc
,
mask
=
mask
)
def
int8_matmul_rowwise_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
def
int8_matmul_rowwise_dequantize
(
a
,
b
,
state_x
,
state_w
,
bias
):
divfactor
=
1.
/
(
127.
*
127.
)
divfactor
=
1.
0
/
(
127.
0
*
127.
0
)
has_bias
=
0
if
bias
is
None
else
1
has_bias
=
0
if
bias
is
None
else
1
...
@@ -154,12 +179,28 @@ else:
...
@@ -154,12 +179,28 @@ else:
# allocates output
# allocates output
c
=
torch
.
empty
((
M
,
N
),
device
=
device
,
dtype
=
torch
.
float16
)
c
=
torch
.
empty
((
M
,
N
),
device
=
device
,
dtype
=
torch
.
float16
)
# accumulator types
# accumulator types
ACC_TYPE
=
tl
.
float32
#
if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE
=
tl
.
float32
#
if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
# launch int8_matmul_rowwise_dequantize kernel
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
'BLOCK_M'
])
*
triton
.
cdiv
(
N
,
META
[
'BLOCK_N'
]),
META
[
'SPLIT_K'
])
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_N"
]),
META
[
"SPLIT_K"
])
_int8_matmul_rowwise_dequantize
[
grid
](
a
,
b
,
c
,
bias
,
state_x
,
state_w
,
M
,
N
,
K
,
divfactor
,
has_bias
,
_int8_matmul_rowwise_dequantize
[
grid
](
a
.
stride
(
0
),
a
.
stride
(
1
),
a
,
b
.
stride
(
0
),
b
.
stride
(
1
),
b
,
c
.
stride
(
0
),
c
.
stride
(
1
),
c
,
GROUP_M
=
8
,
ACC_TYPE
=
ACC_TYPE
)
bias
,
state_x
,
state_w
,
M
,
N
,
K
,
divfactor
,
has_bias
,
a
.
stride
(
0
),
a
.
stride
(
1
),
b
.
stride
(
0
),
b
.
stride
(
1
),
c
.
stride
(
0
),
c
.
stride
(
1
),
GROUP_M
=
8
,
ACC_TYPE
=
ACC_TYPE
,
)
return
c
return
c
bitsandbytes/triton/quantize_columnwise_and_transpose.py
View file @
5a4263f4
...
@@ -5,9 +5,10 @@ import torch
...
@@ -5,9 +5,10 @@ import torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.triton.triton_utils
import
is_triton_available
if
not
is_triton_available
():
if
not
is_triton_available
():
def
quantize_columnwise_and_transpose
(
x
:
torch
.
Tensor
):
return
None
else
:
def
quantize_columnwise_and_transpose
(
x
:
torch
.
Tensor
):
return
None
else
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
...
@@ -15,23 +16,23 @@ else:
...
@@ -15,23 +16,23 @@ else:
# TODO: autotune this better.
# TODO: autotune this better.
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_stages
=
1
),
triton
.
Config
({},
num_stages
=
1
),
triton
.
Config
({},
num_stages
=
2
),
triton
.
Config
({},
num_stages
=
2
),
triton
.
Config
({},
num_stages
=
4
),
triton
.
Config
({},
num_stages
=
4
),
triton
.
Config
({},
num_stages
=
8
),
triton
.
Config
({},
num_stages
=
8
),
triton
.
Config
({},
num_stages
=
16
),
triton
.
Config
({},
num_stages
=
16
),
triton
.
Config
({},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
2
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
2
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
4
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
4
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
8
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
8
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
16
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
16
,
num_warps
=
8
),
triton
.
Config
({},
num_warps
=
1
),
triton
.
Config
({},
num_warps
=
1
),
triton
.
Config
({},
num_warps
=
2
),
triton
.
Config
({},
num_warps
=
2
),
triton
.
Config
({},
num_warps
=
4
),
triton
.
Config
({},
num_warps
=
4
),
triton
.
Config
({},
num_warps
=
8
),
triton
.
Config
({},
num_warps
=
8
),
],
],
key
=
[
'
n_elements
'
]
key
=
[
"
n_elements
"
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_quantize_columnwise_and_transpose
(
def
_quantize_columnwise_and_transpose
(
...
@@ -39,7 +40,8 @@ else:
...
@@ -39,7 +40,8 @@ else:
output_ptr
,
output_ptr
,
output_maxs
,
output_maxs
,
n_elements
,
n_elements
,
M
:
tl
.
constexpr
,
N
:
tl
.
constexpr
,
M
:
tl
.
constexpr
,
N
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
P2
:
tl
.
constexpr
,
P2
:
tl
.
constexpr
,
):
):
...
@@ -47,12 +49,12 @@ else:
...
@@ -47,12 +49,12 @@ else:
block_start
=
pid
block_start
=
pid
p2_arange
=
tl
.
arange
(
0
,
P2
)
p2_arange
=
tl
.
arange
(
0
,
P2
)
p2_arange_mask
=
p2_arange
<
M
p2_arange_mask
=
p2_arange
<
M
arange
=
p2_arange
*
N
arange
=
p2_arange
*
N
offsets
=
block_start
+
arange
offsets
=
block_start
+
arange
x
=
tl
.
load
(
x_ptr
+
offsets
,
mask
=
p2_arange_mask
)
x
=
tl
.
load
(
x_ptr
+
offsets
,
mask
=
p2_arange_mask
)
abs_x
=
tl
.
abs
(
x
)
abs_x
=
tl
.
abs
(
x
)
max_val
=
tl
.
max
(
tl
.
where
(
p2_arange_mask
,
abs_x
,
0
),
axis
=
0
)
max_val
=
tl
.
max
(
tl
.
where
(
p2_arange_mask
,
abs_x
,
0
),
axis
=
0
)
output
=
tl
.
libdevice
.
llrint
(
127.
*
(
x
/
max_val
))
output
=
tl
.
libdevice
.
llrint
(
127.
0
*
(
x
/
max_val
))
new_start
=
pid
*
M
new_start
=
pid
*
M
new_offsets
=
new_start
+
p2_arange
new_offsets
=
new_start
+
p2_arange
...
@@ -68,6 +70,6 @@ else:
...
@@ -68,6 +70,6 @@ else:
assert
x
.
is_cuda
and
output
.
is_cuda
assert
x
.
is_cuda
and
output
.
is_cuda
n_elements
=
output
.
numel
()
n_elements
=
output
.
numel
()
grid
=
lambda
meta
:
(
triton
.
cdiv
(
n_elements
,
meta
[
'
BLOCK_SIZE
'
]),)
grid
=
lambda
meta
:
(
triton
.
cdiv
(
n_elements
,
meta
[
"
BLOCK_SIZE
"
]),)
_quantize_columnwise_and_transpose
[
grid
](
x
,
output
,
output_maxs
,
n_elements
,
M
,
N
,
BLOCK_SIZE
=
M
,
P2
=
P2
)
_quantize_columnwise_and_transpose
[
grid
](
x
,
output
,
output_maxs
,
n_elements
,
M
,
N
,
BLOCK_SIZE
=
M
,
P2
=
P2
)
return
output
,
output_maxs
return
output
,
output_maxs
bitsandbytes/triton/quantize_global.py
View file @
5a4263f4
import
torch
import
torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.triton.triton_utils
import
is_triton_available
if
not
is_triton_available
():
if
not
is_triton_available
():
def
quantize_global_transpose
(
input
):
return
None
def
quantize_global
(
x
:
torch
.
Tensor
):
return
None
else
:
def
quantize_global_transpose
(
input
):
return
None
def
quantize_global
(
x
:
torch
.
Tensor
):
return
None
else
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
# global quantize
# global quantize
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({
'BLOCK_SIZE'
:
1024
,},
num_warps
=
4
),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
},
num_warps
=
4
),
triton
.
Config
({
'BLOCK_SIZE'
:
2048
,},
num_stages
=
1
),
triton
.
Config
({
"BLOCK_SIZE"
:
2048
},
num_stages
=
1
),
],
],
key
=
[
"n_elements"
],
key
=
[
'n_elements'
]
)
)
@
triton
.
jit
@
triton
.
jit
def
_quantize_global
(
def
_quantize_global
(
...
@@ -34,35 +35,43 @@ else:
...
@@ -34,35 +35,43 @@ else:
mask
=
offsets
<
n_elements
mask
=
offsets
<
n_elements
x
=
tl
.
load
(
x_ptr
+
offsets
,
mask
=
mask
)
x
=
tl
.
load
(
x_ptr
+
offsets
,
mask
=
mask
)
absmax_inv
=
tl
.
load
(
absmax_inv_ptr
)
absmax_inv
=
tl
.
load
(
absmax_inv_ptr
)
output
=
tl
.
libdevice
.
llrint
(
127.
*
(
x
*
absmax_inv
))
output
=
tl
.
libdevice
.
llrint
(
127.
0
*
(
x
*
absmax_inv
))
tl
.
store
(
output_ptr
+
offsets
,
output
,
mask
=
mask
)
tl
.
store
(
output_ptr
+
offsets
,
output
,
mask
=
mask
)
def
quantize_global
(
x
:
torch
.
Tensor
):
def
quantize_global
(
x
:
torch
.
Tensor
):
absmax
=
x
.
abs
().
max
().
unsqueeze
(
0
)
absmax
=
x
.
abs
().
max
().
unsqueeze
(
0
)
absmax_inv
=
1.
/
absmax
absmax_inv
=
1.
0
/
absmax
output
=
torch
.
empty
(
*
x
.
shape
,
device
=
'
cuda
'
,
dtype
=
torch
.
int8
)
output
=
torch
.
empty
(
*
x
.
shape
,
device
=
"
cuda
"
,
dtype
=
torch
.
int8
)
assert
x
.
is_cuda
and
output
.
is_cuda
assert
x
.
is_cuda
and
output
.
is_cuda
n_elements
=
output
.
numel
()
n_elements
=
output
.
numel
()
grid
=
lambda
meta
:
(
triton
.
cdiv
(
n_elements
,
meta
[
'
BLOCK_SIZE
'
]),)
grid
=
lambda
meta
:
(
triton
.
cdiv
(
n_elements
,
meta
[
"
BLOCK_SIZE
"
]),)
_quantize_global
[
grid
](
x
,
absmax_inv
,
output
,
n_elements
)
_quantize_global
[
grid
](
x
,
absmax_inv
,
output
,
n_elements
)
return
output
,
absmax
return
output
,
absmax
# global quantize and transpose
# global quantize and transpose
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
128
,
'GROUP_M'
:
8
},
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"GROUP_M"
:
8
},
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
128
,
'GROUP_M'
:
8
},
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"GROUP_M"
:
8
},
num_warps
=
4
),
# ...
# ...
],
],
key
=
[
"M"
,
"N"
],
key
=
[
'M'
,
'N'
]
)
)
@
triton
.
jit
@
triton
.
jit
def
_quantize_global_transpose
(
A
,
absmax_inv_ptr
,
B
,
stride_am
,
stride_an
,
stride_bn
,
stride_bm
,
M
,
N
,
def
_quantize_global_transpose
(
BLOCK_M
:
tl
.
constexpr
,
A
,
BLOCK_N
:
tl
.
constexpr
,
absmax_inv_ptr
,
GROUP_M
:
tl
.
constexpr
):
B
,
stride_am
,
stride_an
,
stride_bn
,
stride_bm
,
M
,
N
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
pid
=
tl
.
program_id
(
0
)
grid_m
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
grid_m
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
grid_n
=
(
N
+
BLOCK_N
-
1
)
//
BLOCK_N
grid_n
=
(
N
+
BLOCK_N
-
1
)
//
BLOCK_N
...
@@ -86,20 +95,30 @@ else:
...
@@ -86,20 +95,30 @@ else:
B
=
B
+
(
rm
[:,
None
]
*
stride_bm
+
rn
[
None
,
:]
*
stride_bn
)
B
=
B
+
(
rm
[:,
None
]
*
stride_bm
+
rn
[
None
,
:]
*
stride_bn
)
mask
=
(
rm
<
M
)[:,
None
]
&
(
rn
<
N
)[
None
,
:]
mask
=
(
rm
<
M
)[:,
None
]
&
(
rn
<
N
)[
None
,
:]
output
=
tl
.
libdevice
.
llrint
(
127.
*
(
a
*
absmax_inv
))
output
=
tl
.
libdevice
.
llrint
(
127.
0
*
(
a
*
absmax_inv
))
tl
.
store
(
B
,
output
,
mask
=
mask
)
tl
.
store
(
B
,
output
,
mask
=
mask
)
def
quantize_global_transpose
(
input
):
def
quantize_global_transpose
(
input
):
absmax
=
input
.
abs
().
max
().
unsqueeze
(
0
)
absmax
=
input
.
abs
().
max
().
unsqueeze
(
0
)
absmax_inv
=
1.
/
absmax
absmax_inv
=
1.
0
/
absmax
M
,
N
=
input
.
shape
M
,
N
=
input
.
shape
out
=
torch
.
empty
(
N
,
M
,
device
=
'
cuda
'
,
dtype
=
torch
.
int8
)
out
=
torch
.
empty
(
N
,
M
,
device
=
"
cuda
"
,
dtype
=
torch
.
int8
)
assert
out
.
size
(
0
)
==
N
and
out
.
size
(
1
)
==
M
assert
out
.
size
(
0
)
==
N
and
out
.
size
(
1
)
==
M
assert
input
.
stride
(
0
)
==
1
or
input
.
stride
(
1
)
==
1
assert
input
.
stride
(
0
)
==
1
or
input
.
stride
(
1
)
==
1
assert
out
.
stride
(
0
)
==
1
or
out
.
stride
(
1
)
==
1
assert
out
.
stride
(
0
)
==
1
or
out
.
stride
(
1
)
==
1
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
'BLOCK_M'
])
*
triton
.
cdiv
(
N
,
META
[
'BLOCK_N'
]),)
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_N"
]),)
_quantize_global_transpose
[
grid
](
input
,
absmax_inv
,
out
,
input
.
stride
(
0
),
input
.
stride
(
1
),
out
.
stride
(
0
),
out
.
stride
(
1
),
M
,
N
)
_quantize_global_transpose
[
grid
](
input
,
absmax_inv
,
out
,
input
.
stride
(
0
),
input
.
stride
(
1
),
out
.
stride
(
0
),
out
.
stride
(
1
),
M
,
N
,
)
return
out
,
absmax
return
out
,
absmax
bitsandbytes/triton/quantize_rowwise.py
View file @
5a4263f4
...
@@ -5,9 +5,10 @@ import torch
...
@@ -5,9 +5,10 @@ import torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.triton.triton_utils
import
is_triton_available
if
not
is_triton_available
():
if
not
is_triton_available
():
def
quantize_rowwise
(
x
:
torch
.
Tensor
):
return
None
else
:
def
quantize_rowwise
(
x
:
torch
.
Tensor
):
return
None
else
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
...
@@ -15,21 +16,21 @@ else:
...
@@ -15,21 +16,21 @@ else:
# TODO: autotune this better.
# TODO: autotune this better.
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
2
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
2
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
4
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
4
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
8
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
8
,
num_warps
=
8
),
triton
.
Config
({},
num_stages
=
1
),
triton
.
Config
({},
num_stages
=
1
),
triton
.
Config
({},
num_stages
=
2
),
triton
.
Config
({},
num_stages
=
2
),
triton
.
Config
({},
num_stages
=
4
),
triton
.
Config
({},
num_stages
=
4
),
triton
.
Config
({},
num_stages
=
8
),
triton
.
Config
({},
num_stages
=
8
),
triton
.
Config
({},
num_warps
=
1
),
triton
.
Config
({},
num_warps
=
1
),
triton
.
Config
({},
num_warps
=
2
),
triton
.
Config
({},
num_warps
=
2
),
triton
.
Config
({},
num_warps
=
4
),
triton
.
Config
({},
num_warps
=
4
),
triton
.
Config
({},
num_warps
=
8
),
triton
.
Config
({},
num_warps
=
8
),
],
],
key
=
[
'
n_elements
'
]
key
=
[
"
n_elements
"
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_quantize_rowwise
(
def
_quantize_rowwise
(
...
@@ -49,7 +50,7 @@ else:
...
@@ -49,7 +50,7 @@ else:
abs_x
=
tl
.
abs
(
x
)
abs_x
=
tl
.
abs
(
x
)
max_val
=
tl
.
max
(
tl
.
where
(
row_mask
,
abs_x
,
0
),
axis
=
0
)
max_val
=
tl
.
max
(
tl
.
where
(
row_mask
,
abs_x
,
0
),
axis
=
0
)
output
=
tl
.
libdevice
.
llrint
(
127.
*
(
x
/
max_val
))
output
=
tl
.
libdevice
.
llrint
(
127.
0
*
(
x
/
max_val
))
tl
.
store
(
output_ptr
+
offsets
,
output
,
mask
=
row_mask
)
tl
.
store
(
output_ptr
+
offsets
,
output
,
mask
=
row_mask
)
tl
.
store
(
output_maxs
+
pid
,
max_val
)
tl
.
store
(
output_maxs
+
pid
,
max_val
)
...
...
bitsandbytes/utils.py
View file @
5a4263f4
...
@@ -30,7 +30,7 @@ def outlier_hook(module, input):
...
@@ -30,7 +30,7 @@ def outlier_hook(module, input):
# (1) zscore test of std of hidden dimension
# (1) zscore test of std of hidden dimension
outlier_idx
=
find_outlier_dims
(
merged
,
reduction_dim
=
1
,
zscore
=
3
)
outlier_idx
=
find_outlier_dims
(
merged
,
reduction_dim
=
1
,
zscore
=
3
)
# (2) magnitude > 6 test
# (2) magnitude > 6 test
dims
=
(
torch
.
abs
(
input
[
0
])
>
6
).
sum
(
dim
=
list
(
range
(
len
(
input
[
0
].
shape
)
-
1
)))
dims
=
(
torch
.
abs
(
input
[
0
])
>
6
).
sum
(
dim
=
list
(
range
(
len
(
input
[
0
].
shape
)
-
1
)))
outlier_idx2
=
torch
.
where
(
dims
>
0
)[
0
]
outlier_idx2
=
torch
.
where
(
dims
>
0
)[
0
]
outlier_idx
=
torch
.
cat
([
outlier_idx
,
outlier_idx2
]).
unique
()
outlier_idx
=
torch
.
cat
([
outlier_idx
,
outlier_idx2
]).
unique
()
tracer
.
hvalue2outlier_idx
[
hvalue
]
=
outlier_idx
tracer
.
hvalue2outlier_idx
[
hvalue
]
=
outlier_idx
...
@@ -59,14 +59,14 @@ class OutlierTracer:
...
@@ -59,14 +59,14 @@ class OutlierTracer:
self
.
hooks
.
append
(
m
.
register_forward_pre_hook
(
outlier_hook
))
self
.
hooks
.
append
(
m
.
register_forward_pre_hook
(
outlier_hook
))
def
is_initialized
(
self
):
def
is_initialized
(
self
):
return
getattr
(
self
,
'
initialized
'
,
False
)
return
getattr
(
self
,
"
initialized
"
,
False
)
def
get_hvalue
(
self
,
weight
):
def
get_hvalue
(
self
,
weight
):
return
weight
.
data
.
storage
().
data_ptr
()
return
weight
.
data
.
storage
().
data_ptr
()
def
get_outliers
(
self
,
weight
):
def
get_outliers
(
self
,
weight
):
if
not
self
.
is_initialized
():
if
not
self
.
is_initialized
():
print
(
'
Outlier tracer is not initialized...
'
)
print
(
"
Outlier tracer is not initialized...
"
)
return
None
return
None
hvalue
=
self
.
get_hvalue
(
weight
)
hvalue
=
self
.
get_hvalue
(
weight
)
if
hvalue
in
self
.
hvalue2outlier_idx
:
if
hvalue
in
self
.
hvalue2outlier_idx
:
...
@@ -80,6 +80,7 @@ class OutlierTracer:
...
@@ -80,6 +80,7 @@ class OutlierTracer:
cls
.
_instance
=
cls
.
__new__
(
cls
)
cls
.
_instance
=
cls
.
__new__
(
cls
)
return
cls
.
_instance
return
cls
.
_instance
def
find_outlier_dims
(
weight
,
reduction_dim
=
0
,
zscore
=
4.0
,
topk
=
None
,
rdm
=
False
):
def
find_outlier_dims
(
weight
,
reduction_dim
=
0
,
zscore
=
4.0
,
topk
=
None
,
rdm
=
False
):
if
rdm
:
if
rdm
:
return
torch
.
randint
(
0
,
weight
.
shape
[
1
],
size
=
(
topk
,),
device
=
weight
.
device
).
long
()
return
torch
.
randint
(
0
,
weight
.
shape
[
1
],
size
=
(
topk
,),
device
=
weight
.
device
).
long
()
...
@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
...
@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
m
=
weight
.
mean
(
reduction_dim
)
m
=
weight
.
mean
(
reduction_dim
)
mm
=
m
.
mean
()
mm
=
m
.
mean
()
mstd
=
m
.
std
()
mstd
=
m
.
std
()
zm
=
(
m
-
mm
)
/
mstd
zm
=
(
m
-
mm
)
/
mstd
std
=
weight
.
std
(
reduction_dim
)
std
=
weight
.
std
(
reduction_dim
)
stdm
=
std
.
mean
()
stdm
=
std
.
mean
()
stdstd
=
std
.
std
()
stdstd
=
std
.
std
()
zstd
=
(
std
-
stdm
)
/
stdstd
zstd
=
(
std
-
stdm
)
/
stdstd
if
topk
is
not
None
:
if
topk
is
not
None
:
val
,
idx
=
torch
.
topk
(
std
.
abs
(),
k
=
topk
,
dim
=
0
)
val
,
idx
=
torch
.
topk
(
std
.
abs
(),
k
=
topk
,
dim
=
0
)
...
@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
...
@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
def
execute_and_return
(
command_string
:
str
)
->
Tuple
[
str
,
str
]:
def
execute_and_return
(
command_string
:
str
)
->
Tuple
[
str
,
str
]:
def
_decode
(
subprocess_err_out_tuple
):
def
_decode
(
subprocess_err_out_tuple
):
return
tuple
(
return
tuple
(
to_decode
.
decode
(
"UTF-8"
).
strip
()
for
to_decode
in
subprocess_err_out_tuple
)
to_decode
.
decode
(
"UTF-8"
).
strip
()
for
to_decode
in
subprocess_err_out_tuple
)
def
execute_and_return_decoded_std_streams
(
command_string
):
def
execute_and_return_decoded_std_streams
(
command_string
):
return
_decode
(
return
_decode
(
...
@@ -116,14 +114,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
...
@@ -116,14 +114,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
shlex
.
split
(
command_string
),
shlex
.
split
(
command_string
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
).
communicate
()
).
communicate
()
,
)
)
std_out
,
std_err
=
execute_and_return_decoded_std_streams
(
command_string
)
std_out
,
std_err
=
execute_and_return_decoded_std_streams
(
command_string
)
return
std_out
,
std_err
return
std_out
,
std_err
def
replace_linear
(
def
replace_linear
(
model
,
model
,
linear_replacement
,
linear_replacement
,
...
@@ -163,8 +160,9 @@ def replace_linear(
...
@@ -163,8 +160,9 @@ def replace_linear(
model
.
_modules
[
name
].
bias
=
old_module
.
bias
model
.
_modules
[
name
].
bias
=
old_module
.
bias
if
post_processing_function
is
not
None
:
if
post_processing_function
is
not
None
:
func
=
getattr
(
module
,
post_processing_function
,
None
)
func
=
getattr
(
module
,
post_processing_function
,
None
)
if
func
is
not
None
:
func
(
module
)
if
func
is
not
None
:
func
(
module
)
return
model
return
model
...
@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
...
@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
A torch tensor containing the packed data.
A torch tensor containing the packed data.
"""
"""
json_str
=
json
.
dumps
(
source_dict
)
json_str
=
json
.
dumps
(
source_dict
)
json_bytes
=
json_str
.
encode
(
'
utf-8
'
)
json_bytes
=
json_str
.
encode
(
"
utf-8
"
)
tensor_data
=
torch
.
tensor
(
list
(
json_bytes
),
dtype
=
torch
.
uint8
)
tensor_data
=
torch
.
tensor
(
list
(
json_bytes
),
dtype
=
torch
.
uint8
)
return
tensor_data
return
tensor_data
...
@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
...
@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
A Python dictionary containing the unpacked data.
A Python dictionary containing the unpacked data.
"""
"""
json_bytes
=
bytes
(
tensor_data
.
cpu
().
numpy
())
json_bytes
=
bytes
(
tensor_data
.
cpu
().
numpy
())
json_str
=
json_bytes
.
decode
(
'
utf-8
'
)
json_str
=
json_bytes
.
decode
(
"
utf-8
"
)
unpacked_dict
=
json
.
loads
(
json_str
)
unpacked_dict
=
json
.
loads
(
json_str
)
return
unpacked_dict
return
unpacked_dict
check_bnb_install.py
View file @
5a4263f4
...
@@ -2,14 +2,14 @@ import torch
...
@@ -2,14 +2,14 @@ import torch
import
bitsandbytes
as
bnb
import
bitsandbytes
as
bnb
p
=
torch
.
nn
.
Parameter
(
torch
.
rand
(
10
,
10
).
cuda
())
p
=
torch
.
nn
.
Parameter
(
torch
.
rand
(
10
,
10
).
cuda
())
a
=
torch
.
rand
(
10
,
10
).
cuda
()
a
=
torch
.
rand
(
10
,
10
).
cuda
()
p1
=
p
.
data
.
sum
().
item
()
p1
=
p
.
data
.
sum
().
item
()
adam
=
bnb
.
optim
.
Adam
([
p
])
adam
=
bnb
.
optim
.
Adam
([
p
])
out
=
a
*
p
out
=
a
*
p
loss
=
out
.
sum
()
loss
=
out
.
sum
()
loss
.
backward
()
loss
.
backward
()
adam
.
step
()
adam
.
step
()
...
@@ -17,5 +17,5 @@ adam.step()
...
@@ -17,5 +17,5 @@ adam.step()
p2
=
p
.
data
.
sum
().
item
()
p2
=
p
.
data
.
sum
().
item
()
assert
p1
!=
p2
assert
p1
!=
p2
print
(
'
SUCCESS!
'
)
print
(
"
SUCCESS!
"
)
print
(
'
Installation was successful!
'
)
print
(
"
Installation was successful!
"
)
examples/int8_inference_huggingface.py
View file @
5a4263f4
...
@@ -2,23 +2,18 @@ import torch
...
@@ -2,23 +2,18 @@ import torch
from
transformers
import
LlamaForCausalLM
,
LlamaTokenizer
from
transformers
import
LlamaForCausalLM
,
LlamaTokenizer
MAX_NEW_TOKENS
=
128
MAX_NEW_TOKENS
=
128
model_name
=
'
meta-llama/Llama-2-7b-hf
'
model_name
=
"
meta-llama/Llama-2-7b-hf
"
text
=
'
Hamburg is in which country?
\n
'
text
=
"
Hamburg is in which country?
\n
"
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
model_name
)
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
model_name
)
input_ids
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
input_ids
input_ids
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
input_ids
max_memory
=
f
'
{
int
(
torch
.
cuda
.
mem_get_info
()[
0
]
/
1024
**
3
)
-
2
}
GB
'
max_memory
=
f
"
{
int
(
torch
.
cuda
.
mem_get_info
()[
0
]
/
1024
**
3
)
-
2
}
GB
"
n_gpus
=
torch
.
cuda
.
device_count
()
n_gpus
=
torch
.
cuda
.
device_count
()
max_memory
=
{
i
:
max_memory
for
i
in
range
(
n_gpus
)}
max_memory
=
{
i
:
max_memory
for
i
in
range
(
n_gpus
)}
model
=
LlamaForCausalLM
.
from_pretrained
(
model
=
LlamaForCausalLM
.
from_pretrained
(
model_name
,
device_map
=
"auto"
,
load_in_8bit
=
True
,
max_memory
=
max_memory
)
model_name
,
device_map
=
'auto'
,
load_in_8bit
=
True
,
max_memory
=
max_memory
)
generated_ids
=
model
.
generate
(
input_ids
,
max_length
=
MAX_NEW_TOKENS
)
generated_ids
=
model
.
generate
(
input_ids
,
max_length
=
MAX_NEW_TOKENS
)
print
(
tokenizer
.
decode
(
generated_ids
[
0
],
skip_special_tokens
=
True
))
print
(
tokenizer
.
decode
(
generated_ids
[
0
],
skip_special_tokens
=
True
))
install_cuda.py
View file @
5a4263f4
...
@@ -19,6 +19,7 @@ cuda_versions = {
...
@@ -19,6 +19,7 @@ cuda_versions = {
"123"
:
"https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run"
,
"123"
:
"https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run"
,
}
}
def
install_cuda
(
version
,
base_path
,
download_path
):
def
install_cuda
(
version
,
base_path
,
download_path
):
formatted_version
=
f
"
{
version
[:
-
1
]
}
.
{
version
[
-
1
]
}
"
formatted_version
=
f
"
{
version
[:
-
1
]
}
.
{
version
[
-
1
]
}
"
folder
=
f
"cuda-
{
formatted_version
}
"
folder
=
f
"cuda-
{
formatted_version
}
"
...
@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
...
@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
subprocess
.
run
([
"rm"
,
"-rf"
,
install_path
],
check
=
True
)
subprocess
.
run
([
"rm"
,
"-rf"
,
install_path
],
check
=
True
)
url
=
cuda_versions
[
version
]
url
=
cuda_versions
[
version
]
filename
=
url
.
split
(
'/'
)[
-
1
]
filename
=
url
.
split
(
"/"
)[
-
1
]
filepath
=
os
.
path
.
join
(
download_path
,
filename
)
filepath
=
os
.
path
.
join
(
download_path
,
filename
)
if
not
os
.
path
.
exists
(
filepath
):
if
not
os
.
path
.
exists
(
filepath
):
...
@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
...
@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
# Install CUDA
# Install CUDA
print
(
f
"Installing CUDA version
{
version
}
..."
)
print
(
f
"Installing CUDA version
{
version
}
..."
)
install_command
=
[
install_command
=
[
"bash"
,
filepath
,
"bash"
,
"--no-drm"
,
"--no-man-page"
,
"--override"
,
filepath
,
"--toolkitpath="
+
install_path
,
"--toolkit"
,
"--silent"
"--no-drm"
,
"--no-man-page"
,
"--override"
,
"--toolkitpath="
+
install_path
,
"--toolkit"
,
"--silent"
,
]
]
print
(
f
"Running command:
{
' '
.
join
(
install_command
)
}
"
)
print
(
f
"Running command:
{
' '
.
join
(
install_command
)
}
"
)
...
@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
...
@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
print
(
f
"CUDA version
{
version
}
installed at
{
install_path
}
"
)
print
(
f
"CUDA version
{
version
}
installed at
{
install_path
}
"
)
def
main
():
def
main
():
user_base_path
=
os
.
path
.
expanduser
(
"~/cuda"
)
user_base_path
=
os
.
path
.
expanduser
(
"~/cuda"
)
system_base_path
=
"/usr/local/cuda"
system_base_path
=
"/usr/local/cuda"
...
@@ -93,5 +100,6 @@ def main():
...
@@ -93,5 +100,6 @@ def main():
print
(
f
"Invalid CUDA version:
{
version
}
. Available versions are:
{
', '
.
join
(
cuda_versions
.
keys
())
}
"
)
print
(
f
"Invalid CUDA version:
{
version
}
. Available versions are:
{
', '
.
join
(
cuda_versions
.
keys
())
}
"
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
scripts/stale.py
View file @
5a4263f4
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
Script to close stale issue. Taken in part from the AllenNLP repository.
Script to close stale issue. Taken in part from the AllenNLP repository.
https://github.com/allenai/allennlp.
https://github.com/allenai/allennlp.
"""
"""
from
datetime
import
datetime
as
dt
,
timezone
from
datetime
import
datetime
as
dt
,
timezone
import
os
import
os
...
@@ -50,7 +51,7 @@ def main():
...
@@ -50,7 +51,7 @@ def main():
issue
.
create_comment
(
issue
.
create_comment
(
"This issue has been automatically marked as stale because it has not had "
"This issue has been automatically marked as stale because it has not had "
"recent activity. If you think this still needs to be addressed "
"recent activity. If you think this still needs to be addressed "
"please comment on this thread.
\n\n
"
"please comment on this thread.
\n\n
"
,
)
)
...
...
tests/test_autograd.py
View file @
5a4263f4
...
@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
...
@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@
pytest
.
mark
.
parametrize
(
"dim2"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim3"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim3"
))
@
pytest
.
mark
.
parametrize
(
"dim3"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim3"
))
@
pytest
.
mark
.
parametrize
(
"dim4"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim4"
))
@
pytest
.
mark
.
parametrize
(
"dim4"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim4"
))
@
pytest
.
mark
.
parametrize
(
"funcs"
,
[(
torch
.
bmm
,
bnb
.
bmm_cublas
),
(
torch
.
matmul
,
bnb
.
matmul_cublas
)],
ids
=
[
"func=bmm"
,
"func=matmul"
])
@
pytest
.
mark
.
parametrize
(
"funcs"
,
[(
torch
.
bmm
,
bnb
.
bmm_cublas
),
(
torch
.
matmul
,
bnb
.
matmul_cublas
)],
ids
=
[
"func=bmm"
,
"func=matmul"
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TUPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TUPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
BOOLEAN_TUPLES
,
ids
=
id_formatter
(
"transpose"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
BOOLEAN_TUPLES
,
ids
=
id_formatter
(
"transpose"
))
...
@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
dim3
=
dim3
-
(
dim3
%
16
)
dim3
=
dim3
-
(
dim3
%
16
)
dim4
=
dim4
-
(
dim4
%
16
)
dim4
=
dim4
-
(
dim4
%
16
)
for
i
in
range
(
25
):
for
i
in
range
(
25
):
# normal multiply
# normal multiply
if
funcs
[
0
]
in
[
torch
.
mm
,
torch
.
matmul
]:
if
funcs
[
0
]
in
[
torch
.
mm
,
torch
.
matmul
]:
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
A
=
torch
.
randn
(
size
=
dimA
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
0
])
A
=
torch
.
randn
(
size
=
dimA
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
0
])
B
=
torch
.
randn
(
size
=
dimB
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
])
B
=
torch
.
randn
(
size
=
dimB
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
])
target
=
torch
.
randn
(
target
=
torch
.
randn
(
size
=
(
dim2
,
dim4
),
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
])
size
=
(
dim2
,
dim4
),
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
]
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
...
@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A
.
grad
=
None
A
.
grad
=
None
B
.
grad
=
None
B
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B
.
grad
=
None
B
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.1
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.1
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.02
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.02
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
# batched matrix multiply
# batched matrix multiply
if
funcs
[
0
]
in
[
torch
.
bmm
,
torch
.
matmul
]:
if
funcs
[
0
]
in
[
torch
.
bmm
,
torch
.
matmul
]:
...
@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
n
=
out_bnb
.
numel
()
n
=
out_bnb
.
numel
()
idx
=
torch
.
isclose
(
out_bnb
,
out_torch
,
atol
=
0.01
,
rtol
=
0.1
)
idx
=
torch
.
isclose
(
out_bnb
,
out_torch
,
atol
=
0.01
,
rtol
=
0.1
)
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.01
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.01
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
out_bnb
,
out_torch
,
atol
=
0.027
,
rtol
=
0.2
)
out_bnb
,
out_torch
,
atol
=
0.027
,
rtol
=
0.2
)
if
any
(
req_grad
):
if
any
(
req_grad
):
out_bnb
.
data
.
copy_
(
out_torch
)
out_bnb
.
data
.
copy_
(
out_torch
)
...
@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A
.
grad
=
None
A
.
grad
=
None
B
.
grad
=
None
B
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B
.
grad
=
None
B
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
...
@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A
.
grad
=
None
A
.
grad
=
None
B
.
grad
=
None
B
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B
.
grad
=
None
B
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
...
@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
...
@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
@
pytest
.
mark
.
parametrize
(
"dim3"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim3"
))
@
pytest
.
mark
.
parametrize
(
"dim3"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim3"
))
@
pytest
.
mark
.
parametrize
(
"dim4"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim4"
))
@
pytest
.
mark
.
parametrize
(
"dim4"
,
get_test_dims
(
32
,
96
,
n
=
1
),
ids
=
id_formatter
(
"dim4"
))
@
pytest
.
mark
.
parametrize
(
"decomp"
,
[
0.0
,
6.0
],
ids
=
id_formatter
(
"decomp"
))
@
pytest
.
mark
.
parametrize
(
"decomp"
,
[
0.0
,
6.0
],
ids
=
id_formatter
(
"decomp"
))
@
pytest
.
mark
.
parametrize
(
"funcs"
,
[(
torch
.
matmul
,
bnb
.
matmul
),
(
torch
.
matmul
,
bnb
.
research
.
switchback_bnb
)],
ids
=
[
"func=matmul"
,
"func=switchback_bnb"
])
@
pytest
.
mark
.
parametrize
(
"funcs"
,
[(
torch
.
matmul
,
bnb
.
matmul
),
(
torch
.
matmul
,
bnb
.
research
.
switchback_bnb
)],
ids
=
[
"func=matmul"
,
"func=switchback_bnb"
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TRIPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TRIPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
TRANSPOSE_VALS
,
ids
=
id_formatter
(
"transpose"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
TRANSPOSE_VALS
,
ids
=
id_formatter
(
"transpose"
))
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_fp16_weights"
))
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_fp16_weights"
))
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_bias"
))
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_bias"
))
def
test_matmullt
(
def
test_matmullt
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
decomp
,
has_fp16_weights
,
has_bias
):
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
decomp
,
has_fp16_weights
,
has_bias
):
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
outlier_dim
=
torch
.
randint
(
0
,
dimA
[
1
],
size
=
(
dimA
[
1
]
//
8
,),
device
=
"cuda"
)
outlier_dim
=
torch
.
randint
(
0
,
dimA
[
1
],
size
=
(
dimA
[
1
]
//
8
,),
device
=
"cuda"
)
...
@@ -245,18 +222,13 @@ def test_matmullt(
...
@@ -245,18 +222,13 @@ def test_matmullt(
req_grad
[
2
]
=
False
req_grad
[
2
]
=
False
for
i
in
range
(
3
):
for
i
in
range
(
3
):
# normal multiply
# normal multiply
if
funcs
[
0
]
in
[
torch
.
mm
,
torch
.
matmul
]:
if
funcs
[
0
]
in
[
torch
.
mm
,
torch
.
matmul
]:
A
=
torch
.
randn
(
A
=
torch
.
randn
(
size
=
dimA
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
0
],
dtype
=
dtype
)
size
=
dimA
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
0
],
dtype
=
dtype
)
if
decomp
==
6.0
:
if
decomp
==
6.0
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
A
[:,
outlier_dim
]
=
6.0
A
[:,
outlier_dim
]
=
6.0
B
=
torch
.
randn
(
B
=
torch
.
randn
(
size
=
dimB
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
],
dtype
=
dtype
)
size
=
dimB
,
device
=
"cuda"
,
requires_grad
=
req_grad
[
1
],
dtype
=
dtype
)
target
=
torch
.
randn
(
target
=
torch
.
randn
(
size
=
(
dim2
,
dim4
),
size
=
(
dim2
,
dim4
),
device
=
"cuda"
,
device
=
"cuda"
,
...
@@ -266,7 +238,7 @@ def test_matmullt(
...
@@ -266,7 +238,7 @@ def test_matmullt(
bias
=
None
bias
=
None
bias2
=
None
bias2
=
None
if
has_bias
:
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'
cuda
'
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias
=
torch
.
randn
(
dim4
,
device
=
"
cuda
"
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias2
=
bias
.
clone
()
bias2
=
bias
.
clone
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B2
=
B
.
clone
()
B2
=
B
.
clone
()
...
@@ -311,9 +283,7 @@ def test_matmullt(
...
@@ -311,9 +283,7 @@ def test_matmullt(
if
any
(
req_grad
):
if
any
(
req_grad
):
out_bnb
.
data
.
copy_
(
out_torch
)
out_bnb
.
data
.
copy_
(
out_torch
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
loss_bnb
=
torch
.
nn
.
functional
.
mse_loss
(
loss_bnb
=
torch
.
nn
.
functional
.
mse_loss
(
out_bnb
,
target
).
mean
()
out_bnb
,
target
).
mean
()
loss_bnb
.
backward
()
loss_bnb
.
backward
()
gradA1
=
A
.
grad
gradA1
=
A
.
grad
gradB1
=
B
.
grad
gradB1
=
B
.
grad
...
@@ -323,9 +293,7 @@ def test_matmullt(
...
@@ -323,9 +293,7 @@ def test_matmullt(
gradBias1
=
bias
.
grad
gradBias1
=
bias
.
grad
bias
.
grad
=
None
bias
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -336,9 +304,7 @@ def test_matmullt(
...
@@ -336,9 +304,7 @@ def test_matmullt(
bias
.
grad
=
None
bias
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
if
dim2
>
0
:
if
dim2
>
0
:
...
@@ -352,9 +318,7 @@ def test_matmullt(
...
@@ -352,9 +318,7 @@ def test_matmullt(
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.1
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.1
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.02
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.02
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
if
req_grad
[
2
]:
if
req_grad
[
2
]:
torch
.
testing
.
assert_close
(
gradBias1
,
gradBias2
)
torch
.
testing
.
assert_close
(
gradBias1
,
gradBias2
)
...
@@ -370,8 +334,20 @@ def test_matmullt(
...
@@ -370,8 +334,20 @@ def test_matmullt(
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_bias"
))
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"has_bias"
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"compress_statistics"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"compress_statistics"
))
@
pytest
.
mark
.
parametrize
(
"compress_statistics"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"compress_statistics"
))
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
],
ids
=
id_formatter
(
"quant_type"
))
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"fp4"
,
"nf4"
],
ids
=
id_formatter
(
"quant_type"
))
def
test_matmul_4bit
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
):
def
test_matmul_4bit
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
,
):
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
if
has_bias
==
False
:
if
has_bias
==
False
:
...
@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias
=
None
bias
=
None
bias2
=
None
bias2
=
None
if
has_bias
:
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'
cuda
'
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias
=
torch
.
randn
(
dim4
,
device
=
"
cuda
"
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias2
=
bias
.
clone
()
bias2
=
bias
.
clone
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_4bit
(
B
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_4bit
(
B
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
,
)
if
not
transpose
[
0
]
and
transpose
[
1
]:
if
not
transpose
[
0
]
and
transpose
[
1
]:
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
...
@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
if
n
>
0
:
if
n
>
0
:
assert
err
<
0.115
assert
err
<
0.115
#assert err < 0.20
#
assert err < 0.20
if
any
(
req_grad
):
if
any
(
req_grad
):
out_bnb
.
data
.
copy_
(
out_torch
)
out_bnb
.
data
.
copy_
(
out_torch
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
gradBias1
=
bias
.
grad
gradBias1
=
bias
.
grad
bias
.
grad
=
None
bias
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias
.
grad
=
None
bias
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
2
]:
if
req_grad
[
2
]:
torch
.
testing
.
assert_close
(
gradBias1
,
gradBias2
)
torch
.
testing
.
assert_close
(
gradBias1
,
gradBias2
)
...
@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TRIPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"req_grad"
,
BOOLEAN_TRIPLES
,
ids
=
id_formatter
(
"req_grad"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
TRANSPOSE_VALS
,
ids
=
id_formatter
(
"transpose"
))
@
pytest
.
mark
.
parametrize
(
"transpose"
,
TRANSPOSE_VALS
,
ids
=
id_formatter
(
"transpose"
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"funcs"
,
[(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8_mixed
),
(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8_global
)],
ids
=
[
"matmul_fp8_mixed"
,
'matmul_fp8_global'
])
@
pytest
.
mark
.
parametrize
(
def
test_matmul_fp8
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
):
"funcs"
,
[(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8_mixed
),
(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8_global
)],
ids
=
[
"matmul_fp8_mixed"
,
"matmul_fp8_global"
],
)
def
test_matmul_fp8
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
):
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
req_grad
=
list
(
req_grad
)
req_grad
=
list
(
req_grad
)
...
@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
...
@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
err
=
torch
.
abs
(
out_bnb
-
out_torch
).
float
().
mean
().
item
()
err
=
torch
.
abs
(
out_bnb
-
out_torch
).
float
().
mean
().
item
()
if
n
>
0
:
if
n
>
0
:
assert
err
<
0.115
assert
err
<
0.115
#assert err < 0.20
#
assert err < 0.20
if
any
(
req_grad
):
if
any
(
req_grad
):
out_bnb
.
data
.
copy_
(
out_torch
)
out_bnb
.
data
.
copy_
(
out_torch
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
...
@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A
.
grad
=
None
A
.
grad
=
None
B
.
grad
=
None
B
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
...
@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B
.
grad
=
None
B
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
...
@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
...
@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.1
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.1
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.02
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.02
grad_err
=
(
gradB1
-
gradB2
).
abs
().
mean
()
grad_err
=
(
gradB1
-
gradB2
).
abs
().
mean
()
assert
grad_err
.
item
()
<
0.003
assert
grad_err
.
item
()
<
0.003
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
tests/test_cuda_setup_evaluator.py
View file @
5a4263f4
...
@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
...
@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
def
test_get_cuda_bnb_library_path_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
):
def
test_get_cuda_bnb_library_path_nocublaslt
(
monkeypatch
,
cuda111_noblas_spec
):
monkeypatch
.
delenv
(
"BNB_CUDA_VERSION"
,
raising
=
False
)
monkeypatch
.
delenv
(
"BNB_CUDA_VERSION"
,
raising
=
False
)
assert
(
assert
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda111_nocublaslt"
get_cuda_bnb_library_path
(
cuda111_noblas_spec
).
stem
==
"libbitsandbytes_cuda111_nocublaslt"
)
tests/test_functional.py
View file @
5a4263f4
...
@@ -19,9 +19,7 @@ from tests.helpers import (
...
@@ -19,9 +19,7 @@ from tests.helpers import (
id_formatter
,
id_formatter
,
)
)
torch
.
set_printoptions
(
torch
.
set_printoptions
(
precision
=
5
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
)
precision
=
5
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
)
k
=
20
k
=
20
...
@@ -98,9 +96,7 @@ def teardown():
...
@@ -98,9 +96,7 @@ def teardown():
pass
pass
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_estimate_quantiles
(
dtype
):
def
test_estimate_quantiles
(
dtype
):
A
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A
=
A
.
to
(
dtype
)
A
=
A
.
to
(
dtype
)
...
@@ -136,7 +132,6 @@ def test_quantile_quantization():
...
@@ -136,7 +132,6 @@ def test_quantile_quantization():
assert
diff
<
0.001
assert
diff
<
0.001
def
test_dynamic_quantization
():
def
test_dynamic_quantization
():
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
...
@@ -149,8 +144,8 @@ def test_dynamic_quantization():
...
@@ -149,8 +144,8 @@ def test_dynamic_quantization():
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
assert
diff
.
mean
().
item
()
<
0.0135
assert
diff
.
mean
().
item
()
<
0.0135
print
(
sum
(
diffs
)
/
len
(
diffs
))
print
(
sum
(
diffs
)
/
len
(
diffs
))
print
(
sum
(
reldiffs
)
/
len
(
reldiffs
))
print
(
sum
(
reldiffs
)
/
len
(
reldiffs
))
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
...
@@ -161,13 +156,12 @@ def test_dynamic_quantization():
...
@@ -161,13 +156,12 @@ def test_dynamic_quantization():
assert
diff
<
0.004
assert
diff
<
0.004
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"nested"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"nested"
))
@
pytest
.
mark
.
parametrize
(
"nested"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"nested"
))
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
parametrize
(
"signed"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"signed"
))
@
pytest
.
mark
.
parametrize
(
"signed"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"signed"
))
def
test_dynamic_blockwise_quantization
(
dtype
,
nested
,
blocksize
,
signed
):
def
test_dynamic_blockwise_quantization
(
dtype
,
nested
,
blocksize
,
signed
):
#print('')
#
print('')
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
...
@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#
print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
#
print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert
abserr
<
0.011
assert
abserr
<
0.011
assert
relerr
<
0.018
assert
relerr
<
0.018
assert
A2
.
dtype
==
dtype
assert
A2
.
dtype
==
dtype
...
@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
#
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
if
signed
:
if
signed
:
assert
abserr
<
0.0035
assert
abserr
<
0.0035
assert
relerr
<
0.015
assert
relerr
<
0.015
...
@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
assert
abserr
<
0.00175
assert
abserr
<
0.00175
assert
relerr
<
0.012
assert
relerr
<
0.012
assert
A2
.
dtype
==
dtype
assert
A2
.
dtype
==
dtype
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_percentile_clipping
(
gtype
):
def
test_percentile_clipping
(
gtype
):
gnorm_vec1
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec1
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec2
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec2
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
...
@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
...
@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
step
+=
1
step
+=
1
g
=
torch
.
randn
(
n
,
n
,
dtype
=
gtype
,
device
=
"cuda"
)
g
=
torch
.
randn
(
n
,
n
,
dtype
=
gtype
,
device
=
"cuda"
)
gnorm1
,
clip2
,
gnorm_scale
=
F
.
percentile_clipping
(
gnorm1
,
clip2
,
gnorm_scale
=
F
.
percentile_clipping
(
g
,
gnorm_vec2
,
step
,
percentile
=
percentile
)
g
,
gnorm_vec2
,
step
,
percentile
=
percentile
)
assert
gnorm_scale
==
1.0
if
gnorm1
<
clip2
else
clip2
/
gnorm1
assert
gnorm_scale
==
1.0
if
gnorm1
<
clip2
else
clip2
/
gnorm1
gnorm2
=
torch
.
norm
(
g
.
float
())
gnorm2
=
torch
.
norm
(
g
.
float
())
...
@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2
=
dim2
-
(
dim2
%
32
)
dim2
=
dim2
-
(
dim2
%
32
)
errors
=
[]
errors
=
[]
relerrors
=
[]
relerrors
=
[]
#print("")
#
print("")
for
i
in
range
(
5
):
for
i
in
range
(
5
):
if
batched
:
if
batched
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
32
,
dim1
,
dim2
//
32
),
device
=
"cuda"
)
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
32
,
dim1
,
dim2
//
32
),
device
=
"cuda"
)
...
@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim2
,
dim1
),
device
=
"cuda"
)
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim2
,
dim1
),
device
=
"cuda"
)
maxA
,
Ac
=
quant_methods
[
0
](
A
,
1
)
maxA
,
Ac
=
quant_methods
[
0
](
A
,
1
)
maxB
,
Bc
=
quant_methods
[
1
](
B
,
0
)
maxB
,
Bc
=
quant_methods
[
1
](
B
,
0
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
quant_methods
[
2
](
maxA
,
Ac
),
A
,
atol
=
0.025
,
rtol
=
0.05
)
quant_methods
[
2
](
maxA
,
Ac
),
A
,
atol
=
0.025
,
rtol
=
0.05
)
if
batched
:
if
batched
:
out2
=
torch
.
bmm
(
A
,
B
)
out2
=
torch
.
bmm
(
A
,
B
)
C
=
torch
.
bmm
(
Ac
.
float
(),
Bc
.
float
())
C
=
torch
.
bmm
(
Ac
.
float
(),
Bc
.
float
())
...
@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr
=
err
/
torch
.
abs
(
out2
)
relerr
=
err
/
torch
.
abs
(
out2
)
errors
.
append
(
err
.
mean
().
item
())
errors
.
append
(
err
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
#print(mean(errors))
#
print(mean(errors))
#print(mean(relerrors))
#
print(mean(relerrors))
def
test_stable_embedding
():
def
test_stable_embedding
():
...
@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
...
@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
batch_dim
=
batch_dim
-
(
batch_dim
%
16
)
batch_dim
=
batch_dim
-
(
batch_dim
%
16
)
seq_dim
=
seq_dim
-
(
seq_dim
%
16
)
seq_dim
=
seq_dim
-
(
seq_dim
%
16
)
for
i
in
range
(
k
):
for
i
in
range
(
k
):
shapeA
=
(
shapeA
=
(
batch_dim
,
hidden_dim
)
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
(
batch_dim
,
hidden_dim
)
shapeB
=
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
)
shapeB
=
(
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
...
@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
...
@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
shapeA
=
(
batch_dim
,
seq_dim
,
hidden_dim
)
shapeA
=
(
batch_dim
,
seq_dim
,
hidden_dim
)
shapeB
=
(
shapeB
=
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
...
@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
...
@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
hidden_dim
=
hidden_dim
-
(
hidden_dim
%
32
)
hidden_dim
=
hidden_dim
-
(
hidden_dim
%
32
)
batch_dim
=
batch_dim
-
(
batch_dim
%
2
)
batch_dim
=
batch_dim
-
(
batch_dim
%
2
)
for
i
in
range
(
25
):
for
i
in
range
(
25
):
A
=
torch
.
randint
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
).
to
(
torch
.
int8
)
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
).
to
(
torch
.
int8
)
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
).
to
(
torch
.
int8
)
out2
=
torch
.
einsum
(
"bsi, bso->io"
,
A
.
float
(),
B
.
float
())
out2
=
torch
.
einsum
(
"bsi, bso->io"
,
A
.
float
(),
B
.
float
())
iout
=
torch
.
empty
(
iout
=
torch
.
empty
(
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
out
=
F
.
igemm
(
A
,
B
,
out
=
iout
)
out
=
F
.
igemm
(
A
,
B
,
out
=
iout
)
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
)
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
)
...
@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
...
@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2
=
[]
errs2
=
[]
relerrs2
=
[]
relerrs2
=
[]
for
i
in
range
(
k
):
for
i
in
range
(
k
):
A
=
torch
.
normal
(
A
=
torch
.
normal
(
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
if
transpose
:
if
transpose
:
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
256
,
hidden_dim
),
device
=
"cuda"
)
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
256
,
hidden_dim
),
device
=
"cuda"
)
else
:
else
:
...
@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
...
@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
float
())
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
float
())
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
)
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
)
elif
transpose
[
0
]
and
transpose
[
1
]:
elif
transpose
[
0
]
and
transpose
[
1
]:
out2
=
torch
.
bmm
(
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
())
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
()
)
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
.
permute
([
0
,
2
,
1
]))
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
.
permute
([
0
,
2
,
1
]))
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
.
float
())
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
.
float
())
...
@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
...
@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
qA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=
0
)
qA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=
0
)
A1
=
F
.
vectorwise_dequant
(
qA
,
SA
)
A1
=
F
.
vectorwise_dequant
(
qA
,
SA
)
n
=
A1
.
numel
()
n
=
A1
.
numel
()
assert_all_approx_close
(
A1
,
A
,
atol
=
0.01
,
rtol
=
0.1
,
count
=
int
(
n
*
0.002
))
assert_all_approx_close
(
A1
,
A
,
atol
=
0.01
,
rtol
=
0.1
,
count
=
int
(
n
*
0.002
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
2
,
256
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
2
,
256
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
...
@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
dtype
)
out
,
S
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
out
,
S
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
...
@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
if
dims
==
2
:
n
=
A
.
shape
[
0
]
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
n
=
A
.
shape
[
0
]
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
elif
dims
==
3
:
elif
dims
==
3
:
n
=
(
n
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
)
assert
out
.
numel
()
==
n
assert
out
.
numel
()
==
n
elif
orderOut
==
"col_turing"
:
elif
orderOut
==
"col_turing"
:
# 32 col 8 row tiles
# 32 col 8 row tiles
n
=
(
A
.
shape
[
0
]
+
(
8
-
A
.
shape
[
0
]
%
8
))
*
(
n
=
(
A
.
shape
[
0
]
+
(
8
-
A
.
shape
[
0
]
%
8
))
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
))
)
assert
out
.
numel
()
==
n
assert
out
.
numel
()
==
n
total_coltile
=
(
A
.
shape
[
1
]
//
32
)
+
(
1
if
A
.
shape
[
1
]
%
32
!=
0
else
0
)
total_coltile
=
(
A
.
shape
[
1
]
//
32
)
+
(
1
if
A
.
shape
[
1
]
%
32
!=
0
else
0
)
for
row
in
range
(
A
.
shape
[
0
]):
for
row
in
range
(
A
.
shape
[
0
]):
...
@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j
=
col
j
=
col
coltile
=
(
col
//
32
)
+
(
1
if
col
%
32
!=
0
else
0
)
coltile
=
(
col
//
32
)
+
(
1
if
col
%
32
!=
0
else
0
)
rowtile
=
(
rowtile
=
((
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
))
*
total_coltile
(
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
)
)
*
total_coltile
offset
=
32
*
8
*
(
rowtile
+
coltile
)
offset
=
32
*
8
*
(
rowtile
+
coltile
)
col2
=
col
%
32
col2
=
col
%
32
row2
=
(
row
%
8
)
*
32
row2
=
(
row
%
8
)
*
32
...
@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if
orderOut
==
"col32"
:
if
orderOut
==
"col32"
:
out2
,
S
=
F
.
nvidia_transform
(
out2
,
S
=
F
.
nvidia_transform
(
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
torch
.
testing
.
assert_close
(
A
,
out2
)
torch
.
testing
.
assert_close
(
A
,
out2
)
...
@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
def
test_igemmlt_int
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
):
def
test_igemmlt_int
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
torch
.
int8
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
A2
,
SA
=
F
.
transform
(
A
,
"col32"
)
A2
,
SA
=
F
.
transform
(
A
,
"col32"
)
...
@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
...
@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch
.
testing
.
assert_close
(
C1
,
C3
.
float
())
torch
.
testing
.
assert_close
(
C1
,
C3
.
float
())
# transpose
# transpose
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
torch
.
int8
)
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
float
())
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
float
())
B2t
,
SBt
=
F
.
transform
(
B
,
"col_turing"
,
transpose
=
True
)
B2t
,
SBt
=
F
.
transform
(
B
,
"col_turing"
,
transpose
=
True
)
...
@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
...
@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
half
()
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
half
()
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
normal
(
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
B
=
torch
.
randn
((
dim4
,
dim3
),
device
=
"cuda"
).
half
()
B
=
torch
.
randn
((
dim4
,
dim3
),
device
=
"cuda"
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
C1
=
torch
.
matmul
(
A
,
B
.
t
())
...
@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
...
@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# C3, S = F.transform(C2, 'row', state=SC)
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_close(C1, C3.float())
# torch.testing.assert_close(C1, C3.float())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
[
[
...
@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
...
@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
for
i
in
range
(
k
):
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
# fc1
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
# fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# out2 = torch.matmul(out1, w2.t())# fc2
...
@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
...
@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
):
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
):
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
bias
=
None
bias
=
None
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
1
):
for
i
in
range
(
1
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
if
has_bias
:
C1
+=
bias
if
has_bias
:
C1
+=
bias
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
...
@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
...
@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
if
has_bias
:
C4
+=
bias
if
has_bias
:
C4
+=
bias
# TODO: is something wrong here? If so, the problem goes deeper
# TODO: is something wrong here? If so, the problem goes deeper
# n = C1.numel()
# n = C1.numel()
...
@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
...
@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
else
:
else
:
assert
False
assert
False
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
threshold
)
A
,
threshold
=
threshold
)
A_blocked
=
einops
.
rearrange
(
A_blocked
=
einops
.
rearrange
(
torch
.
abs
(
A
),
torch
.
abs
(
A
),
...
@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
...
@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
torch
.
testing
.
assert_close
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_close
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_close
(
nnz_block_ptr1
.
int
(),
nnz_block_ptr2
)
torch
.
testing
.
assert_close
(
nnz_block_ptr1
.
int
(),
nnz_block_ptr2
)
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
0.0
)
A
,
threshold
=
0.0
)
torch
.
testing
.
assert_close
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_close
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_close
(
row_stats1
,
row_stats2
)
torch
.
testing
.
assert_close
(
row_stats1
,
row_stats2
)
...
@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
...
@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
torch
.
testing
.
assert_close
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
n
=
CAt
.
numel
()
n
=
CAt
.
numel
()
num_not_close_rows
=
(
num_not_close_rows
=
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
num_not_close_cols
=
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
num_not_close_cols
=
(
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
# allow for 1:500 error due to rounding differences
# allow for 1:500 error due to rounding differences
min_error
=
1
/
500
min_error
=
1
/
500
if
num_not_close_cols
>
(
min_error
*
n
):
if
num_not_close_cols
>
(
min_error
*
n
):
print
(
print
(
f
"Min error exceeded
{
num_not_close_cols
}
elements are different. Error:
{
num_not_close_cols
/
n
:.
4
f
}
"
)
f
"Min error exceeded
{
num_not_close_cols
}
elements are different. Error:
{
num_not_close_cols
/
n
:.
4
f
}
"
)
assert
False
assert
False
if
num_not_close_rows
>
(
min_error
*
n
):
if
num_not_close_rows
>
(
min_error
*
n
):
print
(
print
(
f
"Min error exceeded
{
num_not_close_rows
}
elements are different. Error:
{
num_not_close_rows
/
n
:.
4
f
}
"
)
f
"Min error exceeded
{
num_not_close_rows
}
elements are different. Error:
{
num_not_close_rows
/
n
:.
4
f
}
"
)
assert
False
assert
False
torch
.
testing
.
assert_close
(
Srow
.
flatten
().
float
(),
statsA
)
torch
.
testing
.
assert_close
(
Srow
.
flatten
().
float
(),
statsA
)
...
@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
...
@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
(
"dim1"
,
"dim4"
,
"inner"
),
(
"dim1"
,
"dim4"
,
"inner"
),
(
(
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
for
(
dim1
,
dim4
,
inner
)
for
(
dim1
,
dim4
,
inner
)
in
zip
(
in
zip
(
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
)
)
)
)
,
)
)
def
test_integrated_igemmlt
(
dim1
,
dim4
,
inner
):
def
test_integrated_igemmlt
(
dim1
,
dim4
,
inner
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
...
@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
...
@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
(
"dim1"
,
"dim4"
,
"inner"
),
(
"dim1"
,
"dim4"
,
"inner"
),
(
(
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
for
(
dim1
,
dim4
,
inner
)
for
(
dim1
,
dim4
,
inner
)
in
zip
(
in
zip
(
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
)
)
)
)
,
)
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
...
@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
...
@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c
=
10.0
*
inner
*
scale
c
=
10.0
*
inner
*
scale
row_scale
=
torch
.
ones_like
(
maxA
)
/
c
row_scale
=
torch
.
ones_like
(
maxA
)
/
c
outC32
,
SC
=
F
.
igemmlt
(
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
C3
,
S
=
F
.
nvidia_transform
(
outC32
,
"row"
,
state
=
SC
)
C3
,
S
=
F
.
nvidia_transform
(
outC32
,
"row"
,
state
=
SC
)
maxval
=
torch
.
abs
(
C3
).
max
()
maxval
=
torch
.
abs
(
C3
).
max
()
if
maxval
==
127
:
if
maxval
==
127
:
...
@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
...
@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
for
i
in
range
(
k
):
outC32
,
SC
=
F
.
igemmlt
(
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
"row-wise"
,
time
.
time
()
-
t0
)
print
(
"row-wise"
,
time
.
time
()
-
t0
)
...
@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
...
@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
def
test_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
def
test_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
dtype
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
10
,
99
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
A
.
view
(
-
1
)[
-
1
]
=
-
1
A
.
view
(
-
1
)[
-
1
]
=
-
1
if
transpose
:
if
transpose
:
...
@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
...
@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
idx
=
torch
.
abs
(
A
)
>=
threshold
idx
=
torch
.
abs
(
A
)
>=
threshold
CA2
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
CA2
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
,
threshold
=
threshold
)
A
,
threshold
=
threshold
)
if
coo_tensor
is
not
None
:
if
coo_tensor
is
not
None
:
A1
=
A
*
idx
A1
=
A
*
idx
A2
=
torch
.
zeros_like
(
A
)
A2
=
torch
.
zeros_like
(
A
)
A2
[
A2
[
coo_tensor
.
rowidx
.
long
(),
coo_tensor
.
colidx
.
long
()]
=
coo_tensor
.
values
coo_tensor
.
rowidx
.
long
(),
coo_tensor
.
colidx
.
long
()
]
=
coo_tensor
.
values
torch
.
testing
.
assert_close
(
A1
,
A2
)
torch
.
testing
.
assert_close
(
A1
,
A2
)
A1
=
A
*
(
idx
==
0
)
A1
=
A
*
(
idx
==
0
)
A2
=
(
CA
.
float
()
*
statsA
.
unsqueeze
(
1
)
/
127
).
half
()
A2
=
(
CA
.
float
()
*
statsA
.
unsqueeze
(
1
)
/
127
).
half
()
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
A
*
(
idx
==
0
),
A2
,
rtol
=
0.05
,
atol
=
1.5e-2
)
A
*
(
idx
==
0
),
A2
,
rtol
=
0.05
,
atol
=
1.5e-2
)
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
1
,
1
*
1024
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
1
,
1
*
1024
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
...
@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
...
@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
if
transposed_B
:
if
transposed_B
:
...
@@ -1303,9 +1219,7 @@ def test_spmm_bench():
...
@@ -1303,9 +1219,7 @@ def test_spmm_bench():
print
(
nnz
/
idx
.
numel
())
print
(
nnz
/
idx
.
numel
())
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
for
i
in
range
(
10
):
for
i
in
range
(
10
):
out2
=
F
.
spmm_coo
(
cooA
,
B
)
out2
=
F
.
spmm_coo
(
cooA
,
B
)
...
@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
...
@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
out2
=
F
.
mm_dequant
(
out1_32
,
Sout1_32
,
statsA
,
statsw1
)
out2
=
F
.
mm_dequant
(
out1_32
,
Sout1_32
,
statsA
,
statsw1
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
,
threshold
=
threshold
)
A
,
threshold
=
threshold
)
C32A
,
SA
=
F
.
transform
(
CA
,
"col32"
)
C32A
,
SA
=
F
.
transform
(
CA
,
"col32"
)
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
...
@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
...
@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
out1
=
torch
.
matmul
(
A2
.
half
(),
B
.
half
())
out1
=
torch
.
matmul
(
A2
.
half
(),
B
.
half
())
out
=
out_func
(
out1
.
shape
,
dtype
=
torch
.
float16
,
device
=
out1
.
device
)
out
=
out_func
(
out1
.
shape
,
dtype
=
torch
.
float16
,
device
=
out1
.
device
)
...
@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
...
@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
std
=
out1
.
std
()
std
=
out1
.
std
()
out1
/=
std
out1
/=
std
out2
/=
std
out2
/=
std
assert_all_approx_close
(
assert_all_approx_close
(
out1
,
out2
.
half
(),
rtol
=
0.01
,
atol
=
3.0e-2
,
count
=
count
)
out1
,
out2
.
half
(),
rtol
=
0.01
,
atol
=
3.0e-2
,
count
=
count
)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col
=
torch
.
randint
(
0
,
A2
.
shape
[
-
1
],
size
=
(
15
,))
idx_col
=
torch
.
randint
(
0
,
A2
.
shape
[
-
1
],
size
=
(
15
,))
...
@@ -1443,9 +1351,7 @@ def test_coo2csr():
...
@@ -1443,9 +1351,7 @@ def test_coo2csr():
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
csrA
=
F
.
coo2csr
(
cooA
)
csrA
=
F
.
coo2csr
(
cooA
)
counts
=
csrA
.
rowptr
[
1
:]
-
csrA
.
rowptr
[:
-
1
]
counts
=
csrA
.
rowptr
[
1
:]
-
csrA
.
rowptr
[:
-
1
]
...
@@ -1463,9 +1369,7 @@ def test_coo2csc():
...
@@ -1463,9 +1369,7 @@ def test_coo2csc():
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
cscA
=
F
.
coo2csc
(
cooA
)
cscA
=
F
.
coo2csc
(
cooA
)
counts
=
cscA
.
colptr
[
1
:]
-
cscA
.
colptr
[:
-
1
]
counts
=
cscA
.
colptr
[
1
:]
-
cscA
.
colptr
[:
-
1
]
...
@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
...
@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
,
dequant_stats
=
statsBt
)
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
,
dequant_stats
=
statsBt
)
out1
=
torch
.
matmul
(
A2
,
B
.
half
())
out1
=
torch
.
matmul
(
A2
,
B
.
half
())
...
@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
...
@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
[
pytest
.
param
(
1
,
1
,
6656
,
4
*
6656
,
id
=
"batch=1, seq=1, model=6656, hidden=26k"
)],
[
pytest
.
param
(
1
,
1
,
6656
,
4
*
6656
,
id
=
"batch=1, seq=1, model=6656, hidden=26k"
)],
)
)
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
...
@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers
=
torch
.
randint
(
0
,
model
,
size
=
(
5
,)).
cuda
()
outliers
=
torch
.
randint
(
0
,
model
,
size
=
(
5
,)).
cuda
()
A
[:,
:,
outliers
]
=
8.0
A
[:,
:,
outliers
]
=
8.0
linearMixedBit
=
(
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
False
,
threshold
=
6.0
).
cuda
().
half
()
)
linearMixedBit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
False
,
threshold
=
6.0
).
cuda
().
half
()
#linearMixedBit.eval()
#
linearMixedBit.eval()
linear8bit_train
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit_train
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit_train_thresh
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
threshold
=
6.0
).
cuda
().
half
()
linear8bit_train_thresh
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
threshold
=
6.0
).
cuda
().
half
()
...
@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
torch
.
matmul
(
A
,
B
.
t
())
torch
.
matmul
(
A
,
B
.
t
())
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"pytorch fp16: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"pytorch fp16: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
,
)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#
print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_4bit
(
A
,
B_nf4
.
t
(),
quant_state
=
state_nf4
)
bnb
.
matmul_4bit
(
A
,
B_nf4
.
t
(),
quant_state
=
state_nf4
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb nf4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb nf4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_4bit
(
A
,
B_nf4_c
.
t
(),
quant_state
=
state_nf4_c
)
bnb
.
matmul_4bit
(
A
,
B_nf4_c
.
t
(),
quant_state
=
state_nf4_c
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb nf4+DQ: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb nf4+DQ: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
# torch.cuda.synchronize()
#torch.cuda.synchronize()
# t0 = time.time()
#t0 = time.time()
# for i in range(iters):
#for i in range(iters):
# bnb.matmul(A, B)
# bnb.matmul(A, B)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# bnb.matmul(A, B, threshold=6.0)
# bnb.matmul(A, B, threshold=6.0)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#C32A, SA = F.transform(CA, "col32")
#
C32A, SA = F.transform(CA, "col32")
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#CxB, SB = F.transform(CB, to_order=formatB)
#
CxB, SB = F.transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1)
#
BA, statsB = F.vectorwise_quant(B, dim=1)
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
#
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit(A)
#
linear8bit(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# linear8bit(A)
# linear8bit(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linearMixedBit(A)
#
linearMixedBit(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# linearMixedBit(A)
# linearMixedBit(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train(A)
#
linear8bit_train(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# linear8bit_train(A)
# linear8bit_train(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#
print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train_thresh(A)
#
linear8bit_train_thresh(A)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# linear8bit_train(A)
# linear8bit_train(A)
#torch.cuda.synchronize()
# torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
def
test_zeropoint
():
def
test_zeropoint
():
def
quant_zp
(
x
):
def
quant_zp
(
x
):
...
@@ -1778,8 +1682,8 @@ def test_zeropoint():
...
@@ -1778,8 +1682,8 @@ def test_zeropoint():
C2
-=
A
.
sum
(
1
).
view
(
-
1
,
1
)
*
zp
C2
-=
A
.
sum
(
1
).
view
(
-
1
,
1
)
*
zp
ca
,
cqa
,
cza
=
quant_zp
(
A
)
ca
,
cqa
,
cza
=
quant_zp
(
A
)
#print(ca.min(), ca.max())
#
print(ca.min(), ca.max())
#print((ca - cza).min(), (ca - cza).max())
#
print((ca - cza).min(), (ca - cza).max())
zp
=
1
zp
=
1
scale
=
2.0
scale
=
2.0
...
@@ -1808,14 +1712,14 @@ def test_zeropoint():
...
@@ -1808,14 +1712,14 @@ def test_zeropoint():
C7
-=
zpa
*
zpb
*
A
.
shape
[
1
]
C7
-=
zpa
*
zpb
*
A
.
shape
[
1
]
C7
/=
qa
*
qb
C7
/=
qa
*
qb
#print("")
#
print("")
# print(C0.flatten()[:10])
# print(C0.flatten()[:10])
#print(C1.flatten()[:10])
#
print(C1.flatten()[:10])
#print(C2.flatten()[:10])
#
print(C2.flatten()[:10])
#print(C3.flatten()[:10])
#
print(C3.flatten()[:10])
#print(C5.flatten()[:10])
#
print(C5.flatten()[:10])
#print(C6.flatten()[:10])
#
print(C6.flatten()[:10])
#print(C7.flatten()[:10])
#
print(C7.flatten()[:10])
err1
=
torch
.
abs
(
C1
-
C2
).
mean
().
item
()
err1
=
torch
.
abs
(
C1
-
C2
).
mean
().
item
()
err2
=
torch
.
abs
(
C1
-
C3
).
mean
().
item
()
err2
=
torch
.
abs
(
C1
-
C3
).
mean
().
item
()
err3
=
torch
.
abs
(
C1
-
C4
).
mean
().
item
()
err3
=
torch
.
abs
(
C1
-
C4
).
mean
().
item
()
...
@@ -1852,16 +1756,15 @@ def test_extract_outliers():
...
@@ -1852,16 +1756,15 @@ def test_extract_outliers():
torch
.
testing
.
assert_close
(
outliers1
,
outliers2
)
torch
.
testing
.
assert_close
(
outliers1
,
outliers2
)
def
test_blockwise_cpu_large
():
def
test_blockwise_cpu_large
():
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
batch
=
128
batch
=
128
seq
=
128
seq
=
128
for
hidden
in
[
128
]:
#
, 14336]:
for
hidden
in
[
128
]:
#
, 14336]:
for
blocksize
in
[
4096
,
16384
]:
for
blocksize
in
[
4096
,
16384
]:
for
i
in
range
(
2
):
for
i
in
range
(
2
):
A1
=
torch
.
randn
(
batch
,
seq
,
hidden
,
device
=
'
cpu
'
)
A1
=
torch
.
randn
(
batch
,
seq
,
hidden
,
device
=
"
cpu
"
)
t0
=
time
.
time
()
t0
=
time
.
time
()
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
C
,
S
=
F
.
quantize_blockwise
(
A1
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
,
blocksize
=
blocksize
)
...
@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
...
@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
# print(sum(reldiffs)/len(reldiffs))
# print(sum(reldiffs)/len(reldiffs))
def
test_fp8_quant
():
def
test_fp8_quant
():
for
e_bits
in
range
(
1
,
7
):
for
e_bits
in
range
(
1
,
7
):
p_bits
=
7
-
e_bits
p_bits
=
7
-
e_bits
code
=
F
.
create_fp8_map
(
True
,
e_bits
,
p_bits
).
cuda
()
code
=
F
.
create_fp8_map
(
True
,
e_bits
,
p_bits
).
cuda
()
abserr
=
[]
abserr
=
[]
...
@@ -1888,12 +1790,12 @@ def test_fp8_quant():
...
@@ -1888,12 +1790,12 @@ def test_fp8_quant():
C
,
SC
=
F
.
quantize_blockwise
(
A1
,
code
=
code
)
C
,
SC
=
F
.
quantize_blockwise
(
A1
,
code
=
code
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
abserr
.
append
(
diff
.
mean
().
item
())
abserr
.
append
(
diff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
#assert diff < 0.0075
#
assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#
print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
#
print(sum(relerr)/len(relerr))
abserr
=
[]
abserr
=
[]
relerr
=
[]
relerr
=
[]
...
@@ -1902,12 +1804,12 @@ def test_fp8_quant():
...
@@ -1902,12 +1804,12 @@ def test_fp8_quant():
C
,
SC
=
F
.
quantize_blockwise
(
A1
,
code
=
code
)
C
,
SC
=
F
.
quantize_blockwise
(
A1
,
code
=
code
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
abserr
.
append
(
diff
.
mean
().
item
())
abserr
.
append
(
diff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
#assert diff < 0.0075
#
assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#
print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
#
print(sum(relerr)/len(relerr))
abserr
=
[]
abserr
=
[]
relerr
=
[]
relerr
=
[]
...
@@ -1916,50 +1818,48 @@ def test_fp8_quant():
...
@@ -1916,50 +1818,48 @@ def test_fp8_quant():
C
,
SC
=
F
.
quantize_blockwise
(
A1
)
C
,
SC
=
F
.
quantize_blockwise
(
A1
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
A2
=
F
.
dequantize_blockwise
(
C
,
SC
)
diff
=
torch
.
abs
(
A1
-
A2
)
diff
=
torch
.
abs
(
A1
-
A2
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
abserr
.
append
(
diff
.
mean
().
item
())
abserr
.
append
(
diff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
relerr
.
append
(
reldiff
.
mean
().
item
())
#assert diff < 0.0075
#
assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#
print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
#
print(3, sum(relerr)/len(relerr))
def
test_few_bit_quant
():
def
test_few_bit_quant
():
# print('')
#print('')
for
bits
in
range
(
2
,
9
):
for
bits
in
range
(
2
,
9
):
#print('='*30, bits, '='*30)
#
print('='*30, bits, '='*30)
for
method
in
[
'
linear
'
,
'
fp8
'
,
'
dynamic
'
,
'
quantile
'
]:
for
method
in
[
"
linear
"
,
"
fp8
"
,
"
dynamic
"
,
"
quantile
"
]:
abserrs
=
[]
abserrs
=
[]
relerrs
=
[]
relerrs
=
[]
code
=
None
code
=
None
if
method
==
'
linear
'
:
if
method
==
"
linear
"
:
code
=
F
.
create_linear_map
(
True
,
total_bits
=
bits
).
cuda
()
code
=
F
.
create_linear_map
(
True
,
total_bits
=
bits
).
cuda
()
elif
method
==
'
fp8
'
:
elif
method
==
"
fp8
"
:
ebits
=
math
.
ceil
(
bits
/
2
)
ebits
=
math
.
ceil
(
bits
/
2
)
pbits
=
bits
-
ebits
-
1
pbits
=
bits
-
ebits
-
1
code
=
F
.
create_fp8_map
(
True
,
ebits
,
pbits
,
bits
).
cuda
()
code
=
F
.
create_fp8_map
(
True
,
ebits
,
pbits
,
bits
).
cuda
()
elif
method
==
'
dynamic
'
:
elif
method
==
"
dynamic
"
:
code
=
F
.
create_dynamic_map
(
True
,
bits
-
0
,
bits
).
cuda
()
code
=
F
.
create_dynamic_map
(
True
,
bits
-
0
,
bits
).
cuda
()
elif
method
==
'
quantile
'
:
elif
method
==
"
quantile
"
:
values
=
torch
.
randn
(
2048
,
2048
,
device
=
'
cuda
'
)
values
=
torch
.
randn
(
2048
,
2048
,
device
=
"
cuda
"
)
code
=
F
.
create_quantile_map
(
values
,
bits
).
cuda
()
code
=
F
.
create_quantile_map
(
values
,
bits
).
cuda
()
# for some data types we have no zero
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have one zero
# for some data types we have two zeros
# for some data types we have two zeros
assert
torch
.
unique
(
code
).
numel
()
in
[
2
**
bits
,
2
**
bits
-
1
],
f
'
bits:
{
bits
}
, method:
{
method
}
'
assert
torch
.
unique
(
code
).
numel
()
in
[
2
**
bits
,
2
**
bits
-
1
],
f
"
bits:
{
bits
}
, method:
{
method
}
"
#print(method, (code==0).sum())
#
print(method, (code==0).sum())
assert
code
.
numel
()
==
256
assert
code
.
numel
()
==
256
for
i
in
range
(
10
):
for
i
in
range
(
10
):
values
=
torch
.
randn
(
1
,
32
,
device
=
"cuda"
)
values
=
torch
.
randn
(
1
,
32
,
device
=
'cuda'
)
values
/=
values
.
abs
().
max
()
values
/=
values
.
abs
().
max
()
#values[values.abs() < 1e-6] += 1e-5
#
values[values.abs() < 1e-6] += 1e-5
q1
=
[]
q1
=
[]
v1
=
[]
v1
=
[]
for
v
in
values
[
0
]:
for
v
in
values
[
0
]:
idx
=
torch
.
abs
(
v
-
code
).
argmin
()
idx
=
torch
.
abs
(
v
-
code
).
argmin
()
q1
.
append
(
idx
.
item
())
q1
.
append
(
idx
.
item
())
v1
.
append
(
code
[
idx
].
item
())
v1
.
append
(
code
[
idx
].
item
())
...
@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
...
@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
v2
=
F
.
dequantize_blockwise
(
q2
,
S2
)
v2
=
F
.
dequantize_blockwise
(
q2
,
S2
)
idx
=
torch
.
isclose
(
q1
.
int
(),
q2
.
int
())
idx
=
torch
.
isclose
(
q1
.
int
(),
q2
.
int
())
err2
=
torch
.
abs
(
v2
-
values
)
err2
=
torch
.
abs
(
v2
-
values
)
abserrs
.
append
(
err2
.
mean
().
item
())
abserrs
.
append
(
err2
.
mean
().
item
())
relerrs
.
append
((
err2
/
(
1e-10
+
values
).
abs
()).
mean
().
item
())
relerrs
.
append
((
err2
/
(
1e-10
+
values
).
abs
()).
mean
().
item
())
if
idx
.
sum
():
if
idx
.
sum
():
# some weird cases
# some weird cases
err1
=
torch
.
abs
(
v1
-
values
).
mean
()
err1
=
torch
.
abs
(
v1
-
values
).
mean
()
#assert err2.mean() <= err1
#
assert err2.mean() <= err1
else
:
else
:
torch
.
testing
.
assert_close
(
q1
,
q2
)
torch
.
testing
.
assert_close
(
q1
,
q2
)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#
print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
#
assert False
def
test_kbit_quantile_estimation
():
def
test_kbit_quantile_estimation
():
for
i
in
range
(
100
):
for
i
in
range
(
100
):
data
=
torch
.
randn
(
1024
,
1024
,
device
=
'
cuda
'
)
data
=
torch
.
randn
(
1024
,
1024
,
device
=
"
cuda
"
)
for
bits
in
range
(
2
,
9
):
for
bits
in
range
(
2
,
9
):
p
=
np
.
linspace
(
1.3e-4
,
1
-
1.3e-4
,
2
**
bits
)
p
=
np
.
linspace
(
1.3e-4
,
1
-
1.3e-4
,
2
**
bits
)
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val2
=
F
.
estimate_quantiles
(
data
,
offset
=
0
,
num_quantiles
=
2
**
bits
)
val2
=
F
.
estimate_quantiles
(
data
,
offset
=
0
,
num_quantiles
=
2
**
bits
)
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
assert
err
<
0.038
assert
err
<
0.038
for
i
in
range
(
100
):
for
i
in
range
(
100
):
data
=
torch
.
randn
(
1024
,
1024
,
device
=
'
cuda
'
)
data
=
torch
.
randn
(
1024
,
1024
,
device
=
"
cuda
"
)
for
bits
in
range
(
2
,
4
):
for
bits
in
range
(
2
,
4
):
total_values
=
2
**
bits
-
1
total_values
=
2
**
bits
-
1
p
=
np
.
linspace
(
0
,
1
,
2
*
total_values
+
1
)
p
=
np
.
linspace
(
0
,
1
,
2
*
total_values
+
1
)
idx
=
np
.
arange
(
1
,
2
*
total_values
+
1
,
2
)
idx
=
np
.
arange
(
1
,
2
*
total_values
+
1
,
2
)
p
=
p
[
idx
]
p
=
p
[
idx
]
offset
=
1
/
(
2
*
total_values
)
offset
=
1
/
(
2
*
total_values
)
p
=
np
.
linspace
(
offset
,
1
-
offset
,
total_values
)
p
=
np
.
linspace
(
offset
,
1
-
offset
,
total_values
)
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val1
=
torch
.
Tensor
(
norm
.
ppf
(
p
)).
cuda
()
val2
=
F
.
estimate_quantiles
(
data
,
num_quantiles
=
2
**
bits
-
1
)
val2
=
F
.
estimate_quantiles
(
data
,
num_quantiles
=
2
**
bits
-
1
)
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
err
=
torch
.
abs
(
val1
-
val2
).
mean
()
assert
err
<
0.035
assert
err
<
0.035
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_dequantization
():
def
test_bench_dequantization
():
a
=
torch
.
rand
(
1024
,
1024
,
device
=
'
cuda
'
).
half
()
a
=
torch
.
rand
(
1024
,
1024
,
device
=
"
cuda
"
).
half
()
code
=
F
.
create_fp8_map
(
True
,
3
,
0
,
4
).
cuda
()
code
=
F
.
create_fp8_map
(
True
,
3
,
0
,
4
).
cuda
()
qa
,
SA
=
F
.
quantize_blockwise
(
a
,
code
=
code
)
qa
,
SA
=
F
.
quantize_blockwise
(
a
,
code
=
code
)
print
(
qa
.
max
())
print
(
qa
.
max
())
max_theoretical_mu
=
1024
*
1024
*
2
/
1024
**
3
/
672
*
1000
*
1000
max_theoretical_mu
=
1024
*
1024
*
2
/
1024
**
3
/
672
*
1000
*
1000
#print(max_theoretical_mu)
#
print(max_theoretical_mu)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
for
i
in
range
(
100
):
qa
,
SA
=
F
.
quantize_blockwise
(
a
)
qa
,
SA
=
F
.
quantize_blockwise
(
a
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/1e6)
# print((time.time()-t0)/1e6)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
...
@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
...
@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
result
=
0
result
=
0
bias
=
3
bias
=
3
sign
,
e1
,
e2
,
p1
=
bits
sign
,
e1
,
e2
,
p1
=
bits
idx
=
sign
*
8
+
e1
*
4
+
e2
*
2
+
p1
*
1
idx
=
sign
*
8
+
e1
*
4
+
e2
*
2
+
p1
*
1
sign
=
-
1.0
if
sign
else
1.0
sign
=
-
1.0
if
sign
else
1.0
exp
=
e1
*
2
+
e2
*
1
exp
=
e1
*
2
+
e2
*
1
if
exp
==
0
:
if
exp
==
0
:
# sub-normal
# sub-normal
if
p1
==
0
:
result
=
0
if
p1
==
0
:
else
:
result
=
sign
*
0.0625
result
=
0
else
:
result
=
sign
*
0.0625
else
:
else
:
# normal
# normal
exp
=
2
**
(
-
exp
+
bias
+
1
)
exp
=
2
**
(
-
exp
+
bias
+
1
)
frac
=
1.5
if
p1
else
1.0
frac
=
1.5
if
p1
else
1.0
result
=
sign
*
exp
*
frac
result
=
sign
*
exp
*
frac
code
[
idx
]
=
result
code
[
idx
]
=
result
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'
cuda
'
,
dtype
=
dtype
)
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"
cuda
"
,
dtype
=
dtype
)
qa
,
SA
=
F
.
quantize_fp4
(
A1
,
blocksize
=
64
)
qa
,
SA
=
F
.
quantize_fp4
(
A1
,
blocksize
=
64
)
A2
=
F
.
dequantize_fp4
(
qa
,
SA
)
A2
=
F
.
dequantize_fp4
(
qa
,
SA
)
err
=
(
A1
-
A2
).
abs
().
float
()
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-8
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-8
)).
mean
()
idx
=
err
>
1.0
idx
=
err
>
1.0
err
=
err
.
mean
()
err
=
err
.
mean
()
...
@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
...
@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'
fp4
'
,
'
nf4
'
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"
fp4
"
,
"
nf4
"
])
def
test_4bit_compressed_stats
(
quant_type
):
def
test_4bit_compressed_stats
(
quant_type
):
for
blocksize
in
[
128
,
64
]:
for
blocksize
in
[
128
,
64
]:
errs1
=
[]
errs1
=
[]
errs2
=
[]
errs2
=
[]
for
i
in
range
(
10
):
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'
cuda
'
).
half
()
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
"
cuda
"
).
half
()
q2
,
SA2
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q2
,
SA2
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
(
q2
,
SA2
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
(
q2
,
SA2
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
(
q3
,
SA3
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
(
q3
,
SA3
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs1
.
append
(
err
.
item
())
errs1
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
err
=
(
A1
-
A3
).
abs
().
float
()
err
=
(
A1
-
A3
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs2
.
append
(
err
.
item
())
errs2
.
append
(
err
.
item
())
...
@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
...
@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type)
# print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type)
# print(sum(errs2)/len(errs2), blocksize, quant_type)
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'nf4'
])
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_bench_4bit_dequant
(
quant_type
):
def
test_bench_4bit_dequant
(
quant_type
):
blocksize
=
256
blocksize
=
256
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'
cuda
'
).
half
()
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
"
cuda
"
).
half
()
qa
,
SA
=
F
.
quantize_4bit
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
qa
,
SA
=
F
.
quantize_4bit
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
input_size
=
a
.
numel
()
/
2
input_size
=
a
.
numel
()
/
2
output_size
=
a
.
numel
()
*
2
output_size
=
a
.
numel
()
*
2
num_bytes
=
input_size
+
output_size
num_bytes
=
input_size
+
output_size
GB
=
num_bytes
/
1e9
GB
=
num_bytes
/
1e9
max_theoretical_s
=
GB
/
768
max_theoretical_s
=
GB
/
768
#print(max_theoretical_s*1e6)
#
print(max_theoretical_s*1e6)
b
=
torch
.
randn
(
128
,
1024
*
12
,
device
=
'
cuda
'
).
half
()
b
=
torch
.
randn
(
128
,
1024
*
12
,
device
=
"
cuda
"
).
half
()
iters
=
100
iters
=
100
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
F
.
dequantize_4bit
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
F
.
dequantize_4bit
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
#b.copy_(a)
#
b.copy_(a)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/iters*1e6)
#
print((time.time()-t0)/iters*1e6)
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
#t0 = time.time()
#
t0 = time.time()
#for i in range(iters):
#
for i in range(iters):
# torch.matmul(b, a.t())
# torch.matmul(b, a.t())
#torch.cuda.synchronize()
# torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
# print((time.time()-t0)/iters*1e6)
def
test_normal_map_tree
():
def
test_normal_map_tree
():
code
=
F
.
create_normal_map
()
code
=
F
.
create_normal_map
()
values
=
code
[:
8
].
tolist
()
+
code
[
-
8
:].
tolist
()
values
=
code
[:
8
].
tolist
()
+
code
[
-
8
:].
tolist
()
num_pivots
=
1
num_pivots
=
1
#print(values)
#
print(values)
while
num_pivots
<
16
:
while
num_pivots
<
16
:
idx
=
list
(
range
(
16
//
num_pivots
//
2
,
16
,
16
//
num_pivots
))
idx
=
list
(
range
(
16
//
num_pivots
//
2
,
16
,
16
//
num_pivots
))
#print(idx)
#
print(idx)
num_pivots
*=
2
num_pivots
*=
2
pivots
=
[]
pivots
=
[]
for
i
in
idx
:
for
i
in
idx
:
pivots
.
append
((
values
[
i
-
1
]
+
values
[
i
])
/
2
)
pivots
.
append
((
values
[
i
-
1
]
+
values
[
i
])
/
2
)
#print(pivots)
#
print(pivots)
@
pytest
.
mark
.
parametrize
(
"double_quant"
,
TRUE_FALSE
,
ids
=
lambda
double_quant
:
f
"DQ_
{
double_quant
}
"
)
@
pytest
.
mark
.
parametrize
(
"double_quant"
,
TRUE_FALSE
,
ids
=
lambda
double_quant
:
f
"DQ_
{
double_quant
}
"
)
@
pytest
.
mark
.
parametrize
(
"storage_type"
,
[
'
nf4
'
,
'
fp4
'
])
@
pytest
.
mark
.
parametrize
(
"storage_type"
,
[
"
nf4
"
,
"
fp4
"
])
@
pytest
.
mark
.
parametrize
(
"kind"
,
[
'
fc1
'
,
'
fc2
'
,
'
attn
'
,
'
attn_packed
'
])
@
pytest
.
mark
.
parametrize
(
"kind"
,
[
"
fc1
"
,
"
fc2
"
,
"
attn
"
,
"
attn_packed
"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"quant_storage"
,
[
torch
.
uint8
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"quant_storage"
,
[
torch
.
uint8
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
,
)
def
test_gemv_4bit
(
dtype
,
storage_type
,
quant_storage
,
double_quant
,
kind
):
def
test_gemv_4bit
(
dtype
,
storage_type
,
quant_storage
,
double_quant
,
kind
):
for
dim
in
[
128
,
256
,
512
,
1024
]:
for
dim
in
[
128
,
256
,
512
,
1024
]:
#for dim in [4*1024]:
#
for dim in [4*1024]:
#for dim in [1*16]:
#
for dim in [1*16]:
errs1
=
[]
errs1
=
[]
errs2
=
[]
errs2
=
[]
errs3
=
[]
errs3
=
[]
...
@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
...
@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2
=
[]
max_errs2
=
[]
max_errs3
=
[]
max_errs3
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
if
kind
==
'fc1'
:
if
kind
==
"fc1"
:
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim
*
4
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
B
=
torch
.
randn
(
dim
*
4
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
/
math
.
sqrt
(
dim
)
elif
kind
==
'fc2'
:
elif
kind
==
"fc2"
:
A
=
torch
.
randn
(
1
,
4
*
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
4
*
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim
,
4
*
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
B
=
torch
.
randn
(
dim
,
4
*
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
/
math
.
sqrt
(
dim
)
elif
kind
==
'attn'
:
elif
kind
==
"attn"
:
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
B
=
torch
.
randn
(
dim
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
/
math
.
sqrt
(
dim
)
elif
kind
==
'attn_packed'
:
elif
kind
==
"attn_packed"
:
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim
*
3
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
B
=
torch
.
randn
(
dim
*
3
,
dim
,
dtype
=
dtype
,
device
=
"cuda"
)
/
math
.
sqrt
(
dim
)
qB
,
state
=
F
.
quantize_4bit
(
B
,
quant_type
=
storage_type
,
compress_statistics
=
double_quant
,
quant_storage
=
quant_storage
)
qB
,
state
=
F
.
quantize_4bit
(
B
,
quant_type
=
storage_type
,
compress_statistics
=
double_quant
,
quant_storage
=
quant_storage
,
)
C3
=
torch
.
matmul
(
A
,
B
.
t
())
C3
=
torch
.
matmul
(
A
,
B
.
t
())
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
state
=
state
)
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
state
=
state
)
A
.
requires_grad
=
True
A
.
requires_grad
=
True
C1
=
bnb
.
matmul_4bit
(
A
,
qB
.
t
(),
state
)
C1
=
bnb
.
matmul_4bit
(
A
,
qB
.
t
(),
state
)
err1
=
(
C1
-
C2
).
abs
().
float
()
err1
=
(
C1
-
C2
).
abs
().
float
()
err2
=
(
C3
-
C2
).
abs
().
float
()
err2
=
(
C3
-
C2
).
abs
().
float
()
err3
=
(
C3
-
C1
).
abs
().
float
()
err3
=
(
C3
-
C1
).
abs
().
float
()
mag1
=
torch
.
abs
(
C1
).
float
()
+
1e-5
mag1
=
torch
.
abs
(
C1
).
float
()
+
1e-5
mag2
=
torch
.
abs
(
C3
).
float
()
+
1e-5
mag2
=
torch
.
abs
(
C3
).
float
()
+
1e-5
mag3
=
torch
.
abs
(
C3
).
float
()
+
1e-5
mag3
=
torch
.
abs
(
C3
).
float
()
+
1e-5
relerr1
=
err1
/
mag1
relerr1
=
err1
/
mag1
relerr2
=
err2
/
mag2
relerr2
=
err2
/
mag2
relerr3
=
err3
/
mag3
relerr3
=
err3
/
mag3
max_err1
=
err1
.
max
()
max_err1
=
err1
.
max
()
max_err2
=
err2
.
max
()
max_err2
=
err2
.
max
()
...
@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
...
@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2
.
append
(
max_err2
.
item
())
max_errs2
.
append
(
max_err2
.
item
())
max_errs3
.
append
(
max_err3
.
item
())
max_errs3
.
append
(
max_err3
.
item
())
c
=
int
(
C1
.
numel
()
*
0.0014
*
(
dim
/
256
))
+
1
c
=
int
(
C1
.
numel
()
*
0.0014
*
(
dim
/
256
))
+
1
c
=
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
,
throw
=
False
)
c
=
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
,
throw
=
False
)
err1
=
sum
(
errs1
)
/
len
(
errs1
)
/
math
.
sqrt
(
dim
)
err1
=
sum
(
errs1
)
/
len
(
errs1
)
/
math
.
sqrt
(
dim
)
err2
=
sum
(
errs2
)
/
len
(
errs2
)
/
math
.
sqrt
(
dim
)
err2
=
sum
(
errs2
)
/
len
(
errs2
)
/
math
.
sqrt
(
dim
)
err3
=
sum
(
errs3
)
/
len
(
errs3
)
/
math
.
sqrt
(
dim
)
err3
=
sum
(
errs3
)
/
len
(
errs3
)
/
math
.
sqrt
(
dim
)
relerr1
=
sum
(
relerrs1
)
/
len
(
relerrs1
)
/
math
.
sqrt
(
dim
)
relerr1
=
sum
(
relerrs1
)
/
len
(
relerrs1
)
/
math
.
sqrt
(
dim
)
relerr2
=
sum
(
relerrs2
)
/
len
(
relerrs2
)
/
math
.
sqrt
(
dim
)
relerr2
=
sum
(
relerrs2
)
/
len
(
relerrs2
)
/
math
.
sqrt
(
dim
)
relerr3
=
sum
(
relerrs3
)
/
len
(
relerrs3
)
/
math
.
sqrt
(
dim
)
relerr3
=
sum
(
relerrs3
)
/
len
(
relerrs3
)
/
math
.
sqrt
(
dim
)
maxerr1
=
sum
(
max_errs1
)
/
len
(
max_errs1
)
/
math
.
sqrt
(
dim
)
maxerr1
=
sum
(
max_errs1
)
/
len
(
max_errs1
)
/
math
.
sqrt
(
dim
)
maxerr2
=
sum
(
max_errs2
)
/
len
(
max_errs2
)
/
math
.
sqrt
(
dim
)
maxerr2
=
sum
(
max_errs2
)
/
len
(
max_errs2
)
/
math
.
sqrt
(
dim
)
maxerr3
=
sum
(
max_errs3
)
/
len
(
max_errs3
)
/
math
.
sqrt
(
dim
)
maxerr3
=
sum
(
max_errs3
)
/
len
(
max_errs3
)
/
math
.
sqrt
(
dim
)
absratio
=
err2
/
err3
absratio
=
err2
/
err3
relratio
=
relerr2
/
relerr3
relratio
=
relerr2
/
relerr3
maxratio
=
relerr2
/
relerr3
maxratio
=
relerr2
/
relerr3
# for debugging if the tests fails
# for debugging if the tests fails
#
#
#print('='*80)
#
print('='*80)
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#
print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(C1.flatten()[-20:])
#
print(C1.flatten()[-20:])
#print(C2.flatten()[-20:])
#
print(C2.flatten()[-20:])
#print(f'inference vs training abs: {err1}')
#
print(f'inference vs training abs: {err1}')
#print(f'inference vs training rel: {relerr1}')
#
print(f'inference vs training rel: {relerr1}')
#print(f'inference vs training max: {maxerr1}')
#
print(f'inference vs training max: {maxerr1}')
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#
print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#
print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
#
print(f'inference vs training vs torch err ratio max: {maxratio}')
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
if
dim
<=
512
:
if
dim
<=
512
:
assert
err1
<
7e-5
assert
err1
<
7e-5
...
@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
...
@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
assert
relratio
<
1.04
and
relratio
>
0.96
assert
relratio
<
1.04
and
relratio
>
0.96
assert
maxratio
<
1.02
and
maxratio
>
0.98
assert
maxratio
<
1.02
and
maxratio
>
0.98
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_managed
():
def
test_managed
():
n
=
32
*
10
n
=
32
*
10
A
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
float32
)
A
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
float32
)
B
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
uint8
)
B
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
uint8
)
B2
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
float32
)
B2
=
F
.
get_paged
(
n
,
n
,
dtype
=
torch
.
float32
)
assert
A
.
is_paged
assert
A
.
is_paged
assert
B
.
is_paged
assert
B
.
is_paged
assert
A
.
page_deviceid
==
0
assert
A
.
page_deviceid
==
0
assert
B
.
page_deviceid
==
0
assert
B
.
page_deviceid
==
0
F
.
fill
(
A
,
17.0
)
F
.
fill
(
A
,
17.0
)
F
.
fill
(
B
,
17
)
F
.
fill
(
B
,
17
)
F
.
fill
(
B2
,
2
)
F
.
fill
(
B2
,
2
)
assert
(
A
==
17
).
sum
().
item
()
==
n
*
n
assert
(
A
==
17
).
sum
().
item
()
==
n
*
n
assert
(
B
==
17
).
sum
().
item
()
==
n
*
n
assert
(
B
==
17
).
sum
().
item
()
==
n
*
n
C
=
A
*
B
.
float
()
C
=
A
*
B
.
float
()
assert
(
C
==
289
).
sum
().
item
()
==
n
*
n
assert
(
C
==
289
).
sum
().
item
()
==
n
*
n
F
.
_mul
(
A
,
B2
)
F
.
_mul
(
A
,
B2
)
F
.
_mul
(
A
,
B2
)
F
.
_mul
(
A
,
B2
)
F
.
_mul
(
A
,
B2
)
F
.
_mul
(
A
,
B2
)
assert
(
A
==
17
*
(
2
**
3
)).
sum
().
item
()
==
n
*
n
assert
(
A
==
17
*
(
2
**
3
)).
sum
().
item
()
==
n
*
n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n
# assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
@
pytest
.
mark
.
parametrize
(
"storage_type"
,
[
'
nf4
'
,
'
fp4
'
],
ids
=
[
'
nf4
'
,
'
fp4
'
])
@
pytest
.
mark
.
parametrize
(
"storage_type"
,
[
"
nf4
"
,
"
fp4
"
],
ids
=
[
"
nf4
"
,
"
fp4
"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"double_quant"
,
[
False
],
ids
=
[
'
DQ_True
'
])
@
pytest
.
mark
.
parametrize
(
"double_quant"
,
[
False
],
ids
=
[
"
DQ_True
"
])
def
test_gemv_eye_4bit
(
storage_type
,
dtype
,
double_quant
):
def
test_gemv_eye_4bit
(
storage_type
,
dtype
,
double_quant
):
dims
=
10
dims
=
10
torch
.
random
.
manual_seed
(
np
.
random
.
randint
(
0
,
412424242
))
torch
.
random
.
manual_seed
(
np
.
random
.
randint
(
0
,
412424242
))
dims
=
get_test_dims
(
0
,
8192
,
n
=
dims
)
dims
=
get_test_dims
(
0
,
8192
,
n
=
dims
)
dims
=
[
dim
+
(
64
-
(
dim
%
64
))
for
dim
in
dims
]
dims
=
[
dim
+
(
64
-
(
dim
%
64
))
for
dim
in
dims
]
#for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
#
for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for
dim
in
dims
:
for
dim
in
dims
:
A
=
torch
.
normal
(
0
,
0.1
,
size
=
(
1
,
1
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
A
=
torch
.
normal
(
0
,
0.1
,
size
=
(
1
,
1
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
B
=
torch
.
eye
(
dim
,
dtype
=
dtype
,
device
=
'
cuda
'
)
B
=
torch
.
eye
(
dim
,
dtype
=
dtype
,
device
=
"
cuda
"
)
qB
,
state
=
F
.
quantize_4bit
(
B
,
quant_type
=
storage_type
,
compress_statistics
=
double_quant
)
qB
,
state
=
F
.
quantize_4bit
(
B
,
quant_type
=
storage_type
,
compress_statistics
=
double_quant
)
C3
=
torch
.
matmul
(
A
,
B
.
t
())
C3
=
torch
.
matmul
(
A
,
B
.
t
())
...
@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
...
@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
torch
.
testing
.
assert_close
(
A
,
C3
)
torch
.
testing
.
assert_close
(
A
,
C3
)
torch
.
testing
.
assert_close
(
A
,
C1
)
torch
.
testing
.
assert_close
(
A
,
C1
)
torch
.
testing
.
assert_close
(
A
,
C2
)
torch
.
testing
.
assert_close
(
A
,
C2
)
#torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
#
torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
#torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
#
torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
tests/test_generation.py
View file @
5a4263f4
...
@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers")
...
@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers")
def
get_4bit_config
():
def
get_4bit_config
():
return
transformers
.
BitsAndBytesConfig
(
return
transformers
.
BitsAndBytesConfig
(
load_in_4bit
=
True
,
load_in_4bit
=
True
,
load_in_8bit
=
False
,
load_in_8bit
=
False
,
llm_int8_threshold
=
6.0
,
llm_int8_threshold
=
6.0
,
llm_int8_has_fp16_weight
=
False
,
llm_int8_has_fp16_weight
=
False
,
bnb_4bit_compute_dtype
=
torch
.
float16
,
bnb_4bit_compute_dtype
=
torch
.
float16
,
bnb_4bit_use_double_quant
=
True
,
bnb_4bit_use_double_quant
=
True
,
bnb_4bit_quant_type
=
'
nf4
'
,
bnb_4bit_quant_type
=
"
nf4
"
,
)
)
def
get_model_and_tokenizer
(
config
):
def
get_model_and_tokenizer
(
config
):
model_name_or_path
,
quant_type
=
config
model_name_or_path
,
quant_type
=
config
bnb_config
=
get_4bit_config
()
bnb_config
=
get_4bit_config
()
if
quant_type
==
'
16bit
'
:
if
quant_type
==
"
16bit
"
:
bnb_config
.
load_in_4bit
=
False
bnb_config
.
load_in_4bit
=
False
else
:
else
:
bnb_config
.
bnb_4bit_quant_type
=
quant_type
bnb_config
.
bnb_4bit_quant_type
=
quant_type
model
=
transformers
.
AutoModelForCausalLM
.
from_pretrained
(
model_name_or_path
,
model
=
transformers
.
AutoModelForCausalLM
.
from_pretrained
(
model_name_or_path
,
quantization_config
=
bnb_config
,
quantization_config
=
bnb_config
,
max_memory
=
{
0
:
'
48GB
'
},
max_memory
=
{
0
:
"
48GB
"
},
device_map
=
'
auto
'
,
device_map
=
"
auto
"
,
torch_dtype
=
torch
.
bfloat16
torch_dtype
=
torch
.
bfloat16
,
).
eval
()
).
eval
()
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_name_or_path
)
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_name_or_path
)
return
model
,
tokenizer
return
model
,
tokenizer
def
get_prompt_for_generation_eval
(
text
,
add_roles
=
True
):
def
get_prompt_for_generation_eval
(
text
,
add_roles
=
True
):
description
=
(
description
=
(
"A chat between a curious human and an artificial intelligence assistant. "
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
if
add_roles
:
if
add_roles
:
prompt
=
f
'
{
description
}
### Human:
{
text
}
### Assistant:
'
prompt
=
f
"
{
description
}
### Human:
{
text
}
### Assistant:
"
else
:
else
:
prompt
=
f
'
{
description
}
{
text
}
'
prompt
=
f
"
{
description
}
{
text
}
"
return
prompt
return
prompt
def
generate
(
model
,
tokenizer
,
text
,
generation_config
,
prompt_func
=
get_prompt_for_generation_eval
):
def
generate
(
model
,
tokenizer
,
text
,
generation_config
,
prompt_func
=
get_prompt_for_generation_eval
):
text
=
prompt_func
(
text
)
text
=
prompt_func
(
text
)
inputs
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
to
(
'
cuda:0
'
)
inputs
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
to
(
"
cuda:0
"
)
outputs
=
model
.
generate
(
inputs
=
inputs
[
'
input_ids
'
],
generation_config
=
generation_config
)
outputs
=
model
.
generate
(
inputs
=
inputs
[
"
input_ids
"
],
generation_config
=
generation_config
)
return
tokenizer
.
decode
(
outputs
[
0
],
skip_special_tokens
=
True
)
return
tokenizer
.
decode
(
outputs
[
0
],
skip_special_tokens
=
True
)
models
=
[
'huggyllama/llama-7b'
,
'bigscience/bloom-1b7'
]
dtypes
=
[
'nf4'
,
'fp4'
]
@
pytest
.
fixture
(
scope
=
'session'
,
params
=
product
(
models
,
dtypes
))
models
=
[
"huggyllama/llama-7b"
,
"bigscience/bloom-1b7"
]
dtypes
=
[
"nf4"
,
"fp4"
]
@
pytest
.
fixture
(
scope
=
"session"
,
params
=
product
(
models
,
dtypes
))
def
model_and_tokenizer
(
request
):
def
model_and_tokenizer
(
request
):
model
,
tokenizer
=
get_model_and_tokenizer
(
request
.
param
)
model
,
tokenizer
=
get_model_and_tokenizer
(
request
.
param
)
yield
request
.
param
,
model
,
tokenizer
yield
request
.
param
,
model
,
tokenizer
...
@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
...
@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
)
)
generation_config
.
max_new_tokens
=
20
generation_config
.
max_new_tokens
=
20
# text = 'Please write down the first 50 digits of pi.'
#text = 'Please write down the first 50 digits of pi.'
# text = get_prompt_for_generation_eval(text)
#text = get_prompt_for_generation_eval(text)
# text += ' Sure, here the first 50 digits of pi: 3.14159'
#text += ' Sure, here the first 50 digits of pi: 3.14159'
n_cases
=
6
n_cases
=
6
text
=
'
3.14159
'
text
=
"
3.14159
"
if
hasattr
(
model
.
config
,
'
quantization_config
'
):
if
hasattr
(
model
.
config
,
"
quantization_config
"
):
model
.
config
.
quantization_config
.
bnb_4bit_compute_dtype
=
dtype
model
.
config
.
quantization_config
.
bnb_4bit_compute_dtype
=
dtype
model
.
config
.
quantization_config
.
bnb_4bit_use_double_quant
=
DQ
model
.
config
.
quantization_config
.
bnb_4bit_use_double_quant
=
DQ
if
not
inference_kernel
:
if
not
inference_kernel
:
text
=
[
text
]
*
n_cases
text
=
[
text
]
*
n_cases
inputs
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
to
(
'
cuda:0
'
)
inputs
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
to
(
"
cuda:0
"
)
x
=
inputs
[
'
input_ids
'
]
x
=
inputs
[
"
input_ids
"
]
outputs
=
[]
outputs
=
[]
if
inference_kernel
:
if
inference_kernel
:
for
i
in
range
(
n_cases
):
for
i
in
range
(
n_cases
):
...
@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
...
@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
outputs
=
model
.
generate
(
x
,
generation_config
=
generation_config
)
outputs
=
model
.
generate
(
x
,
generation_config
=
generation_config
)
outputs
=
[
tokenizer
.
decode
(
output
,
skip_special_tokens
=
True
)
for
output
in
outputs
]
outputs
=
[
tokenizer
.
decode
(
output
,
skip_special_tokens
=
True
)
for
output
in
outputs
]
assert
len
(
outputs
)
==
n_cases
assert
len
(
outputs
)
==
n_cases
failure_count
=
0
failure_count
=
0
for
i
in
range
(
n_cases
):
for
i
in
range
(
n_cases
):
if
not
outputs
[
i
][:
len
(
str
(
math
.
pi
))]
==
str
(
math
.
pi
):
if
not
outputs
[
i
][:
len
(
str
(
math
.
pi
))]
==
str
(
math
.
pi
):
failure_count
+=
1
failure_count
+=
1
failure_max
=
(
2
if
fixture_config
[
0
]
==
'
huggyllama/llama-7b
'
else
4
)
failure_max
=
2
if
fixture_config
[
0
]
==
"
huggyllama/llama-7b
"
else
4
if
failure_count
>
failure_max
:
if
failure_count
>
failure_max
:
print
(
math
.
pi
)
print
(
math
.
pi
)
for
out
in
outputs
:
for
out
in
outputs
:
print
(
out
)
print
(
out
)
raise
ValueError
(
f
'
Failure count:
{
failure_count
}
/
{
n_cases
}
'
)
raise
ValueError
(
f
"
Failure count:
{
failure_count
}
/
{
n_cases
}
"
)
tests/test_linear4bit.py
View file @
5a4263f4
...
@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
device
=
"cuda"
device
=
"cuda"
layer_shape
=
(
300
,
400
)
layer_shape
=
(
300
,
400
)
linear
=
torch
.
nn
.
Linear
(
linear
=
torch
.
nn
.
Linear
(
*
layer_shape
,
dtype
=
original_dtype
,
device
=
"cpu"
)
# original layer
*
layer_shape
,
dtype
=
original_dtype
,
device
=
"cpu"
)
# original layer
# Quantizing original layer
# Quantizing original layer
linear_q
=
bnb
.
nn
.
Linear4bit
(
linear_q
=
bnb
.
nn
.
Linear4bit
(
...
@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_type
=
quant_type
,
quant_type
=
quant_type
,
device
=
"meta"
,
device
=
"meta"
,
)
)
new_weight
=
bnb
.
nn
.
Params4bit
(
new_weight
=
bnb
.
nn
.
Params4bit
(
data
=
linear
.
weight
,
quant_type
=
quant_type
,
requires_grad
=
False
)
data
=
linear
.
weight
,
quant_type
=
quant_type
,
requires_grad
=
False
)
linear_q
.
weight
=
new_weight
linear_q
.
weight
=
new_weight
if
bias
:
if
bias
:
linear_q
.
bias
=
torch
.
nn
.
Parameter
(
linear
.
bias
)
linear_q
.
bias
=
torch
.
nn
.
Parameter
(
linear
.
bias
)
...
@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
...
@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
target_compression
=
(
target_compression
=
(
0.143
if
original_dtype
==
torch
.
float32
else
0.29
0.143
if
original_dtype
==
torch
.
float32
else
0.29
)
# these numbers get lower as weight shape increases
)
# these numbers get lower as weight shape increases
ratio_error_msg
=
f
"quantized_size
{
size_4
:,
}
is larger on disk than
{
target_compression
:.
2
%
}
of original size
{
size_orig
:,
}
"
ratio_error_msg
=
(
f
"quantized_size
{
size_4
:,
}
is larger on disk than
{
target_compression
:.
2
%
}
of original size
{
size_orig
:,
}
"
)
assert
size_ratio
<
target_compression
,
ratio_error_msg
assert
size_ratio
<
target_compression
,
ratio_error_msg
...
...
tests/test_linear8bitlt.py
View file @
5a4263f4
...
@@ -19,6 +19,7 @@ from tests.helpers import (
...
@@ -19,6 +19,7 @@ from tests.helpers import (
# contributed by Alex Borzunov, see:
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_capability
()
<
(
7
,
5
),
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_capability
()
<
(
7
,
5
),
reason
=
"this test requires a turing-generation or newer GPU, see bitsandbytes docs"
,
reason
=
"this test requires a turing-generation or newer GPU, see bitsandbytes docs"
,
...
@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
...
@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
linear_custom
.
state
.
force_no_igemmlt
=
True
linear_custom
.
state
.
force_no_igemmlt
=
True
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear
.
weight
.
data
.
clone
(),
requires_grad
=
False
,
has_fp16_weights
=
False
linear
.
weight
.
data
.
clone
(),
requires_grad
=
False
,
has_fp16_weights
=
False
,
).
to
(
linear
.
weight
.
dtype
)
).
to
(
linear
.
weight
.
dtype
)
linear_custom
.
bias
=
linear
.
bias
linear_custom
.
bias
=
linear
.
bias
linear_custom
=
linear_custom
.
cuda
()
linear_custom
=
linear_custom
.
cuda
()
...
@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
...
@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
@
pytest
.
mark
.
parametrize
(
"force_no_igemmlt"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"force_no_igemmlt"
))
@
pytest
.
mark
.
parametrize
(
"force_no_igemmlt"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"force_no_igemmlt"
))
@
pytest
.
mark
.
parametrize
(
"save_before_forward"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"save_before_forward"
))
@
pytest
.
mark
.
parametrize
(
"save_before_forward"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"save_before_forward"
))
@
pytest
.
mark
.
parametrize
(
"load_before_cuda"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"load_before_cuda"
))
@
pytest
.
mark
.
parametrize
(
"load_before_cuda"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"load_before_cuda"
))
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
,
force_no_igemmlt
,
save_before_forward
,
load_before_cuda
):
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
,
force_no_igemmlt
,
save_before_forward
,
load_before_cuda
,
):
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
...
@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
linear_custom
.
state
.
force_no_igemmlt
=
True
linear_custom
.
state
.
force_no_igemmlt
=
True
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear
.
weight
.
data
.
clone
(),
requires_grad
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
linear
.
weight
.
data
.
clone
(),
requires_grad
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
,
)
)
linear_custom
.
bias
=
linear
.
bias
linear_custom
.
bias
=
linear
.
bias
linear_custom
=
linear_custom
.
cuda
()
linear_custom
=
linear_custom
.
cuda
()
...
...
tests/test_modules.py
View file @
5a4263f4
...
@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
...
@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
def
__init__
(
self
,
dim1
,
dim2
,
has_fp16_weights
=
True
,
memory_efficient_backward
=
False
,
threshold
=
0.0
):
def
__init__
(
self
,
dim1
,
dim2
,
has_fp16_weights
=
True
,
memory_efficient_backward
=
False
,
threshold
=
0.0
):
super
().
__init__
()
super
().
__init__
()
self
.
fc1
=
bnb
.
nn
.
Linear8bitLt
(
self
.
fc1
=
bnb
.
nn
.
Linear8bitLt
(
dim1
,
dim2
,
has_fp16_weights
=
has_fp16_weights
,
memory_efficient_backward
=
memory_efficient_backward
,
dim1
,
threshold
=
threshold
dim2
,
has_fp16_weights
=
has_fp16_weights
,
memory_efficient_backward
=
memory_efficient_backward
,
threshold
=
threshold
,
)
)
self
.
fc2
=
bnb
.
nn
.
Linear8bitLt
(
self
.
fc2
=
bnb
.
nn
.
Linear8bitLt
(
dim2
,
dim1
,
has_fp16_weights
=
has_fp16_weights
,
memory_efficient_backward
=
memory_efficient_backward
,
dim2
,
threshold
=
threshold
dim1
,
has_fp16_weights
=
has_fp16_weights
,
memory_efficient_backward
=
memory_efficient_backward
,
threshold
=
threshold
,
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
...
@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class
LinearFunction
(
torch
.
autograd
.
Function
):
class
LinearFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
get_8bit_linear_trimmed
(
x
,
stochastic
=
False
,
trim_value
=
3.0
):
def
get_8bit_linear_trimmed
(
x
,
stochastic
=
False
,
trim_value
=
3.0
):
round_func
=
(
round_func
=
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
)
norm
=
math
.
sqrt
(
math
.
pi
)
/
math
.
sqrt
(
2.0
)
norm
=
math
.
sqrt
(
math
.
pi
)
/
math
.
sqrt
(
2.0
)
# std = torch.abs(x).mean()*norm
# std = torch.abs(x).mean()*norm
std
=
torch
.
std
(
x
)
std
=
torch
.
std
(
x
)
...
@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function):
...
@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function):
return
x
.
to
(
dtype
)
return
x
.
to
(
dtype
)
def
get_8bit_linear
(
x
,
stochastic
=
False
):
def
get_8bit_linear
(
x
,
stochastic
=
False
):
round_func
=
(
round_func
=
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
)
max1
=
torch
.
abs
(
x
).
max
()
max1
=
torch
.
abs
(
x
).
max
()
x
=
x
/
max1
*
127
x
=
x
/
max1
*
127
x
=
round_func
(
x
)
/
127
*
max1
x
=
round_func
(
x
)
/
127
*
max1
...
@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function):
...
@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
get_8bit_vector_wise
(
x
,
dim
,
stochastic
=
False
):
def
get_8bit_vector_wise
(
x
,
dim
,
stochastic
=
False
):
round_func
=
(
round_func
=
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
)
max1
=
torch
.
amax
(
torch
.
abs
(
x
),
dim
=
dim
,
keepdim
=
True
)
max1
=
torch
.
amax
(
torch
.
abs
(
x
),
dim
=
dim
,
keepdim
=
True
)
max1
[
max1
==
0
]
=
1.0
max1
[
max1
==
0
]
=
1.0
x
=
(
x
*
127
)
/
max1
x
=
(
x
*
127
)
/
max1
...
@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function):
...
@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function):
weight8
,
S1
=
LinearFunction
.
quant
(
weight
,
args
.
quant_type
,
dim
=
1
)
weight8
,
S1
=
LinearFunction
.
quant
(
weight
,
args
.
quant_type
,
dim
=
1
)
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
2
)
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
2
)
outputq
=
bnb
.
functional
.
igemm
(
x8
,
weight8
.
t
())
outputq
=
bnb
.
functional
.
igemm
(
x8
,
weight8
.
t
())
output
=
LinearFunction
.
dequant
(
output
=
LinearFunction
.
dequant
(
outputq
,
S1
,
S2
,
x
.
dtype
,
args
.
quant_type
)
outputq
,
S1
,
S2
,
x
.
dtype
,
args
.
quant_type
)
# if torch.rand(1) < 0.01:
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
# err = torch.abs(output-output32).float()
...
@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function):
...
@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function):
# weight and x are already 8bit
# weight and x are already 8bit
# -> transform grad_output to 8-bit
# -> transform grad_output to 8-bit
if
args
.
use_8bit_training
==
"forward+wgrad"
:
if
args
.
use_8bit_training
==
"forward+wgrad"
:
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output
,
args
.
quant_type
,
dim
=
[
0
,
1
])
grad_output
,
args
.
quant_type
,
dim
=
[
0
,
1
]
)
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
[
0
,
1
])
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
[
0
,
1
])
grad_weight8
=
bnb
.
functional
.
igemm
(
grad_output8
,
x8
)
grad_weight8
=
bnb
.
functional
.
igemm
(
grad_output8
,
x8
)
grad_weight
=
LinearFunction
.
dequant
(
grad_weight
=
LinearFunction
.
dequant
(
grad_weight8
,
S1
,
S2
,
grad_output
.
dtype
,
args
.
quant_type
)
grad_weight8
,
S1
,
S2
,
grad_output
.
dtype
,
args
.
quant_type
)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input
=
grad_output
.
matmul
(
weight
)
grad_input
=
grad_output
.
matmul
(
weight
)
elif
args
.
use_8bit_training
==
"full"
:
elif
args
.
use_8bit_training
==
"full"
:
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output
,
args
.
quant_type
,
dim
=
[
0
,
1
])
grad_output
,
args
.
quant_type
,
dim
=
[
0
,
1
]
)
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
[
0
,
1
])
x8
,
S2
=
LinearFunction
.
quant
(
x
,
args
.
quant_type
,
dim
=
[
0
,
1
])
grad_weight8
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int32
)
grad_weight8
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int32
)
bnb
.
functional
.
igemm
(
grad_output8
,
x8
,
out
=
grad_weight8
)
bnb
.
functional
.
igemm
(
grad_output8
,
x8
,
out
=
grad_weight8
)
grad_weight
=
LinearFunction
.
dequant
(
grad_weight
=
LinearFunction
.
dequant
(
grad_weight8
,
S1
,
S2
,
grad_output
.
dtype
,
args
.
quant_type
)
grad_weight8
,
S1
,
S2
,
grad_output
.
dtype
,
args
.
quant_type
)
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output8
,
S1
=
LinearFunction
.
quant
(
grad_output
,
args
.
quant_type
,
dim
=
2
)
grad_output
,
args
.
quant_type
,
dim
=
2
)
weight8
,
S3
=
LinearFunction
.
quant
(
weight
,
args
.
quant_type
,
dim
=
0
)
weight8
,
S3
=
LinearFunction
.
quant
(
weight
,
args
.
quant_type
,
dim
=
0
)
grad_input8
=
bnb
.
functional
.
igemm
(
grad_output8
,
weight8
)
grad_input8
=
bnb
.
functional
.
igemm
(
grad_output8
,
weight8
)
grad_input
=
LinearFunction
.
dequant
(
grad_input
=
LinearFunction
.
dequant
(
grad_input8
,
S1
,
S3
,
grad_output
.
dtype
,
args
.
quant_type
)
grad_input8
,
S1
,
S3
,
grad_output
.
dtype
,
args
.
quant_type
)
else
:
else
:
grad_input
=
grad_output
.
matmul
(
weight
)
grad_input
=
grad_output
.
matmul
(
weight
)
...
@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
...
@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
opt1
.
zero_grad
(
True
)
opt1
.
zero_grad
(
True
)
opt2
.
step
()
opt2
.
step
()
opt2
.
zero_grad
(
True
)
opt2
.
zero_grad
(
True
)
assert_all_approx_close
(
assert_all_approx_close
(
l1
[
0
].
weight
,
l2
[
0
].
weight
,
rtol
=
1.05
,
atol
=
0.01
,
count
=
2
)
l1
[
0
].
weight
,
l2
[
0
].
weight
,
rtol
=
1.05
,
atol
=
0.01
,
count
=
2
assert_all_approx_close
(
l1
[
1
].
weight
,
l2
[
1
].
weight
,
rtol
=
1.05
,
atol
=
0.01
,
count
=
2
)
)
assert_all_approx_close
(
l1
[
1
].
weight
,
l2
[
1
].
weight
,
rtol
=
1.05
,
atol
=
0.01
,
count
=
2
)
# we do this copy because otherwise we have small divergences over time that add up
# we do this copy because otherwise we have small divergences over time that add up
l1
[
0
].
weight
.
data
.
copy_
(
l2
[
0
].
weight
.
data
)
l1
[
0
].
weight
.
data
.
copy_
(
l2
[
0
].
weight
.
data
)
l1
[
1
].
weight
.
data
.
copy_
(
l2
[
1
].
weight
.
data
)
l1
[
1
].
weight
.
data
.
copy_
(
l2
[
1
].
weight
.
data
)
...
@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
...
@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
@
pytest
.
mark
.
parametrize
(
"threshold"
,
[
0.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"threshold"
,
[
0.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"memory_efficient_backward"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"memory_efficient_backward"
,
[
False
])
def
test_linear8bitlt_no_fp16_weights
(
threshold
,
memory_efficient_backward
):
def
test_linear8bitlt_no_fp16_weights
(
threshold
,
memory_efficient_backward
):
l1
=
(
bnb
.
nn
.
Linear8bitLt
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
).
cuda
().
half
())
l1
=
(
bnb
.
nn
.
Linear8bitLt
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
,
)
.
cuda
()
.
half
()
)
assert
l1
.
weight
.
dtype
==
torch
.
int8
assert
l1
.
weight
.
dtype
==
torch
.
int8
l1
.
eval
()
l1
.
eval
()
...
@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
...
@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if
threshold
>
0
:
if
threshold
>
0
:
assert
mlp
.
fc2
.
state
.
idx
is
not
None
assert
mlp
.
fc2
.
state
.
idx
is
not
None
mlp
=
(
mlp
=
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
).
cuda
().
half
()
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
)
.
cuda
()
.
half
()
)
assert
mlp
.
fc1
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc1
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
...
@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
...
@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if
threshold
>
0
:
if
threshold
>
0
:
assert
mlp
.
fc2
.
state
.
idx
is
not
None
assert
mlp
.
fc2
.
state
.
idx
is
not
None
mlp
=
(
mlp
=
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
).
half
().
cuda
()
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
)
.
half
()
.
cuda
()
)
for
i
in
range
(
100
):
for
i
in
range
(
100
):
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
...
@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
...
@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert
mlp
.
fc1
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc1
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
assert
mlp
.
fc2
.
weight
.
dtype
==
torch
.
int8
mlp
=
(
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
).
half
().
to
(
"cuda"
))
mlp
=
(
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
,
)
.
half
()
.
to
(
"cuda"
)
)
for
i
in
range
(
100
):
for
i
in
range
(
100
):
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
b1
=
torch
.
randn
(
16
,
8
,
32
,
device
=
"cuda"
).
half
()
...
@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
...
@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert
mlp
.
fc2
.
weight
.
device
.
type
==
"cuda"
assert
mlp
.
fc2
.
weight
.
device
.
type
==
"cuda"
mlp
=
MLP8bit
(
mlp
=
MLP8bit
(
32
,
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
32
,
)
64
,
threshold
=
threshold
,
has_fp16_weights
=
False
,
memory_efficient_backward
=
memory_efficient_backward
,
)
w1
,
w2
=
mlp
.
fc1
.
weight
.
clone
().
cuda
(),
mlp
.
fc2
.
weight
.
clone
().
cuda
()
# grab weights before quantization,
w1
,
w2
=
mlp
.
fc1
.
weight
.
clone
().
cuda
(),
mlp
.
fc2
.
weight
.
clone
().
cuda
()
# grab weights before quantization,
mlp
=
mlp
.
cuda
().
half
()
# and this line triggers quantization
mlp
=
mlp
.
cuda
().
half
()
# and this line triggers quantization
...
@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
...
@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
lambda
n_in
,
n_out
,
bias
=
True
:
bnb
.
nn
.
Linear8bitLt
(
n_in
,
n_out
,
bias
=
bias
,
has_fp16_weights
=
False
),
lambda
n_in
,
n_out
,
bias
=
True
:
bnb
.
nn
.
Linear8bitLt
(
n_in
,
n_out
,
bias
=
bias
,
has_fp16_weights
=
False
),
bnb
.
nn
.
LinearFP4
,
bnb
.
nn
.
LinearFP4
,
],
],
ids
=
[
'
Int8Lt
'
,
'
FP4
'
],
ids
=
[
"
Int8Lt
"
,
"
FP4
"
],
)
)
def
test_linear_kbit_fp32_bias
(
module
):
def
test_linear_kbit_fp32_bias
(
module
):
# casts model to fp16 -> int8 automatically
# casts model to fp16 -> int8 automatically
...
@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
...
@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
kbit
[
1
].
bias
.
detach
().
copy_
(
ref
[
1
].
bias
)
kbit
[
1
].
bias
.
detach
().
copy_
(
ref
[
1
].
bias
)
ref
=
ref
.
half
().
cuda
()
ref
=
ref
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
kbit
=
kbit
.
half
().
to
(
'
cuda
'
)
kbit
=
kbit
.
half
().
to
(
"
cuda
"
)
errs1
=
[]
errs1
=
[]
errs2
=
[]
errs2
=
[]
...
@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
...
@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
err1
=
(
out1
-
out2
).
abs
().
float
()
err1
=
(
out1
-
out2
).
abs
().
float
()
err2
=
(
grad1
-
grad2
).
abs
().
float
()
err2
=
(
grad1
-
grad2
).
abs
().
float
()
relerr1
=
(
err1
/
(
out1
.
abs
().
float
()
+
1e-9
)
)
relerr1
=
err1
/
(
out1
.
abs
().
float
()
+
1e-9
)
relerr2
=
(
err2
/
(
grad1
.
abs
().
float
()
+
1e-9
)
)
relerr2
=
err2
/
(
grad1
.
abs
().
float
()
+
1e-9
)
errs1
.
append
(
err1
.
mean
().
item
())
errs1
.
append
(
err1
.
mean
().
item
())
errs2
.
append
(
err2
.
mean
().
item
())
errs2
.
append
(
err2
.
mean
().
item
())
relerrs1
.
append
(
relerr1
.
mean
().
item
())
relerrs1
.
append
(
relerr1
.
mean
().
item
())
...
@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
...
@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
#print('out', sum(errs1)/len(errs1))
#
print('out', sum(errs1)/len(errs1))
#print('grad', sum(errs2)/len(errs2))
#
print('grad', sum(errs2)/len(errs2))
#print('rel out', sum(relerrs1)/len(relerrs1))
#
print('rel out', sum(relerrs1)/len(relerrs1))
#print('rel grad', sum(relerrs2)/len(relerrs2))
#
print('rel grad', sum(relerrs2)/len(relerrs2))
def
test_fp8linear
():
def
test_fp8linear
():
b
=
10
b
=
10
h
=
1024
h
=
1024
inp
=
torch
.
randn
(
b
,
h
).
cuda
()
inp
=
torch
.
randn
(
b
,
h
).
cuda
()
fp32
=
torch
.
nn
.
Linear
(
h
,
h
*
2
).
cuda
()
fp32
=
torch
.
nn
.
Linear
(
h
,
h
*
2
).
cuda
()
fp8
=
bnb
.
research
.
nn
.
LinearFP8Mixed
(
h
,
h
*
2
).
cuda
()
fp8
=
bnb
.
research
.
nn
.
LinearFP8Mixed
(
h
,
h
*
2
).
cuda
()
fp32b
=
torch
.
nn
.
Linear
(
h
*
2
,
h
).
cuda
()
fp32b
=
torch
.
nn
.
Linear
(
h
*
2
,
h
).
cuda
()
fp8b
=
bnb
.
research
.
nn
.
LinearFP8Mixed
(
h
*
2
,
h
).
cuda
()
fp8b
=
bnb
.
research
.
nn
.
LinearFP8Mixed
(
h
*
2
,
h
).
cuda
()
fp8
.
weight
.
data
.
copy_
(
fp32
.
weight
.
data
)
fp8
.
weight
.
data
.
copy_
(
fp32
.
weight
.
data
)
fp8
.
bias
.
data
.
copy_
(
fp32
.
bias
.
data
)
fp8
.
bias
.
data
.
copy_
(
fp32
.
bias
.
data
)
...
@@ -605,34 +603,34 @@ def test_fp8linear():
...
@@ -605,34 +603,34 @@ def test_fp8linear():
a
=
fp32b
(
torch
.
nn
.
functional
.
gelu
(
fp32
(
inp
)))
a
=
fp32b
(
torch
.
nn
.
functional
.
gelu
(
fp32
(
inp
)))
b
=
fp8b
(
torch
.
nn
.
functional
.
gelu
(
fp8
(
inp
)))
b
=
fp8b
(
torch
.
nn
.
functional
.
gelu
(
fp8
(
inp
)))
err
=
(
a
-
b
).
abs
().
mean
()
err
=
(
a
-
b
).
abs
().
mean
()
a
.
mean
().
backward
()
a
.
mean
().
backward
()
b
.
mean
().
backward
()
b
.
mean
().
backward
()
graderr
=
(
fp8
.
weight
.
grad
-
fp32
.
weight
.
grad
).
abs
().
mean
()
graderr
=
(
fp8
.
weight
.
grad
-
fp32
.
weight
.
grad
).
abs
().
mean
()
bgraderr
=
(
fp8
.
bias
.
grad
-
fp32
.
bias
.
grad
).
abs
().
mean
()
bgraderr
=
(
fp8
.
bias
.
grad
-
fp32
.
bias
.
grad
).
abs
().
mean
()
assert
err
<
0.05
assert
err
<
0.05
assert
graderr
<
0.00002
assert
graderr
<
0.00002
assert
bgraderr
<
0.00002
assert
bgraderr
<
0.00002
def
test_4bit_warnings
():
def
test_4bit_warnings
():
dim1
=
64
dim1
=
64
with
pytest
.
warns
(
UserWarning
,
match
=
r
'
inference or training
'
):
with
pytest
.
warns
(
UserWarning
,
match
=
r
"
inference or training
"
):
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
net
.
cuda
()
net
=
net
.
cuda
()
inp
=
torch
.
rand
(
10
,
dim1
).
cuda
().
half
()
inp
=
torch
.
rand
(
10
,
dim1
).
cuda
().
half
()
net
(
inp
)
net
(
inp
)
with
pytest
.
warns
(
UserWarning
,
match
=
r
'
inference.
'
):
with
pytest
.
warns
(
UserWarning
,
match
=
r
"
inference.
"
):
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
net
.
cuda
()
net
=
net
.
cuda
()
inp
=
torch
.
rand
(
1
,
dim1
).
cuda
().
half
()
inp
=
torch
.
rand
(
1
,
dim1
).
cuda
().
half
()
net
(
inp
)
net
(
inp
)
with
pytest
.
warns
(
UserWarning
)
as
record
:
with
pytest
.
warns
(
UserWarning
)
as
record
:
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
nn
.
Sequential
(
*
[
bnb
.
nn
.
Linear4bit
(
dim1
,
dim1
,
compute_dtype
=
torch
.
float32
)
for
i
in
range
(
10
)])
net
=
net
.
cuda
()
net
=
net
.
cuda
()
inp
=
torch
.
rand
(
10
,
dim1
).
cuda
().
half
()
inp
=
torch
.
rand
(
10
,
dim1
).
cuda
().
half
()
...
...
tests/test_optim.py
View file @
5a4263f4
...
@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter
...
@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter
k
=
20
k
=
20
def
assert_most_approx_close
(
a
,
b
,
rtol
=
1e-3
,
atol
=
1e-3
,
max_error_count
=
0
):
def
assert_most_approx_close
(
a
,
b
,
rtol
=
1e-3
,
atol
=
1e-3
,
max_error_count
=
0
):
idx
=
torch
.
isclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
idx
=
torch
.
isclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
error_count
=
(
idx
==
0
).
sum
().
item
()
error_count
=
(
idx
==
0
).
sum
().
item
()
...
@@ -33,6 +34,7 @@ def get_temp_dir():
...
@@ -33,6 +34,7 @@ def get_temp_dir():
def
rm_path
(
path
):
def
rm_path
(
path
):
shutil
.
rmtree
(
path
)
shutil
.
rmtree
(
path
)
str2optimizers
=
{}
str2optimizers
=
{}
str2optimizers
[
"adam_pytorch"
]
=
(
None
,
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
"adam_pytorch"
]
=
(
None
,
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
"lion_pytorch"
]
=
(
None
,
Lion
,
bnb
.
optim
.
Lion
)
str2optimizers
[
"lion_pytorch"
]
=
(
None
,
Lion
,
bnb
.
optim
.
Lion
)
...
@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = (
...
@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = (
)
)
str2optimizers
[
"adam8bit_blockwise"
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"adam8bit_blockwise"
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"paged_adamw8bit_blockwise"
]
=
(
torch
.
optim
.
AdamW
,
lambda
pxx
:
bnb
.
optim
.
PagedAdamW8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"paged_adamw8bit_blockwise"
]
=
(
str2optimizers
[
"paged_adam8bit_blockwise"
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
PagedAdam8bit
(
pxx
,
block_wise
=
True
))
torch
.
optim
.
AdamW
,
lambda
pxx
:
bnb
.
optim
.
PagedAdamW8bit
(
pxx
,
block_wise
=
True
),
)
str2optimizers
[
"paged_adam8bit_blockwise"
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
PagedAdam8bit
(
pxx
,
block_wise
=
True
),
)
str2optimizers
[
"lion8bit_blockwise"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
Lion8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"lion8bit_blockwise"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
Lion8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"paged_lion8bit_blockwise"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
PagedLion8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"paged_lion8bit_blockwise"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
PagedLion8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
"momentum8bit_blockwise"
]
=
(
str2optimizers
[
"momentum8bit_blockwise"
]
=
(
...
@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
...
@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames
[
"rmsprop"
]
=
[(
"square_avg"
,
"state1"
)]
str2statenames
[
"rmsprop"
]
=
[(
"square_avg"
,
"state1"
)]
str2statenames
[
"adam8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
)]
str2statenames
[
"adam8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
)]
str2statenames
[
"lamb8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
)]
str2statenames
[
"lamb8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
)]
str2statenames
[
"adam8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
)]
str2statenames
[
"adam8bit_blockwise"
]
=
[
str2statenames
[
"paged_adam8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
)]
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
str2statenames
[
"paged_adamw8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
)]
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
),
]
str2statenames
[
"paged_adam8bit_blockwise"
]
=
[
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
),
]
str2statenames
[
"paged_adamw8bit_blockwise"
]
=
[
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
),
]
str2statenames
[
"momentum8bit"
]
=
[(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"momentum8bit"
]
=
[(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"lion8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"lion8bit"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"momentum8bit_blockwise"
]
=
[(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
str2statenames
[
"momentum8bit_blockwise"
]
=
[(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
...
@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
...
@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames
[
"lion8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
str2statenames
[
"lion8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
str2statenames
[
"paged_lion8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
str2statenames
[
"paged_lion8bit_blockwise"
]
=
[(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)]
optimizer_names_32bit
=
[
"adam"
,
"momentum"
,
"rmsprop"
,
'
paged_adamw
'
,
'
paged_adam
'
,
'
lion
'
,
'
paged_lion
'
]
optimizer_names_32bit
=
[
"adam"
,
"momentum"
,
"rmsprop"
,
"
paged_adamw
"
,
"
paged_adam
"
,
"
lion
"
,
"
paged_lion
"
]
@
pytest
.
mark
.
parametrize
(
"optim_name"
,
optimizer_names_32bit
,
ids
=
id_formatter
(
"opt"
))
@
pytest
.
mark
.
parametrize
(
"optim_name"
,
optimizer_names_32bit
,
ids
=
id_formatter
(
"opt"
))
...
@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad
...
@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
,
1
],
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
,
1
],
ids
=
id_formatter
(
"dim2"
))
def
test_optimizer32bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
def
test_optimizer32bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
if
gtype
==
torch
.
bfloat16
and
optim_name
in
[
'
momentum
'
,
'
rmsprop
'
]:
if
gtype
==
torch
.
bfloat16
and
optim_name
in
[
"
momentum
"
,
"
rmsprop
"
]:
pytest
.
skip
()
pytest
.
skip
()
if
dim1
==
1
and
dim2
==
1
:
if
dim1
==
1
and
dim2
==
1
:
return
return
...
@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
...
@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
for
name1
,
name2
in
str2statenames
[
optim_name
]:
for
name1
,
name2
in
str2statenames
[
optim_name
]:
# since Lion can have pretty noisy updates where things lie at the boundary
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
# allow up to 10 errors for Lion
assert_most_approx_close
(
torch_optimizer
.
state
[
p1
][
name1
],
bnb_optimizer
.
state
[
p2
][
name2
],
assert_most_approx_close
(
atol
=
atol
,
rtol
=
rtol
,
torch_optimizer
.
state
[
p1
][
name1
],
max_error_count
=
10
)
bnb_optimizer
.
state
[
p2
][
name2
],
atol
=
atol
,
rtol
=
rtol
,
max_error_count
=
10
,
)
if
gtype
!=
torch
.
float32
:
if
gtype
!=
torch
.
float32
:
# the adam buffers should also be close because they are 32-bit
# the adam buffers should also be close because they are 32-bit
...
@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
...
@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
eps
=
1e-8
eps
=
1e-8
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
initialize
()
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
initialize
()
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
override_config
(
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
override_config
(
p3
,
"optim_bits"
,
8
)
p3
,
"optim_bits"
,
8
)
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
register_parameters
(
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
register_parameters
([
p1
,
p2
,
p3
])
[
p1
,
p2
,
p3
]
)
p1
=
p1
.
cuda
()
p1
=
p1
.
cuda
()
p2
=
p2
.
cuda
()
p2
=
p2
.
cuda
()
p3
=
p3
.
cuda
()
p3
=
p3
.
cuda
()
...
@@ -242,7 +259,8 @@ optimizer_names_8bit = [
...
@@ -242,7 +259,8 @@ optimizer_names_8bit = [
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
],
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
],
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
def
test_optimizer8bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
def
test_optimizer8bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
if
gtype
==
torch
.
bfloat16
and
optim_name
not
in
[
'adam8bit_blockwise'
,
'lion8bit_blockwise'
]:
pytest
.
skip
()
if
gtype
==
torch
.
bfloat16
and
optim_name
not
in
[
"adam8bit_blockwise"
,
"lion8bit_blockwise"
]:
pytest
.
skip
()
if
dim1
==
1
and
dim2
==
1
:
if
dim1
==
1
and
dim2
==
1
:
return
return
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
...
@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
absmax
=
bnb_optimizer
.
state
[
p2
][
max_val
],
absmax
=
bnb_optimizer
.
state
[
p2
][
max_val
],
A
=
bnb_optimizer
.
state
[
p2
][
name2
],
A
=
bnb_optimizer
.
state
[
p2
][
name2
],
)
)
num_not_close
=
(
num_not_close
=
torch
.
isclose
(
torch_optimizer
.
state
[
p1
][
name1
],
s1
,
atol
=
atol
,
rtol
=
rtol
)
==
0
torch
.
isclose
(
# assert num_not_close.sum().item() < 20
torch_optimizer
.
state
[
p1
][
name1
],
s1
,
atol
=
atol
,
rtol
=
rtol
)
==
0
)
#assert num_not_close.sum().item() < 20
dequant_states
.
append
(
s1
.
clone
())
dequant_states
.
append
(
s1
.
clone
())
err
=
torch
.
abs
(
p1
-
p2
)
err
=
torch
.
abs
(
p1
-
p2
)
relerr
=
err
/
(
torch
.
abs
(
p1
)
+
1e-9
)
relerr
=
err
/
(
torch
.
abs
(
p1
)
+
1e-9
)
if
g
.
dtype
==
torch
.
bfloat16
:
if
g
.
dtype
==
torch
.
bfloat16
:
assert
err
.
mean
()
<
0.00015
assert
err
.
mean
()
<
0.00015
assert
relerr
.
mean
()
<
0.0016
assert
relerr
.
mean
()
<
0.0016
...
@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors
.
append
(
relerr
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
if
i
%
10
==
0
and
i
>
0
:
if
i
%
10
==
0
and
i
>
0
:
for
(
name1
,
name2
,
qmap
,
max_val
),
s
in
zip
(
for
(
name1
,
name2
,
qmap
,
max_val
),
s
in
zip
(
str2statenames
[
optim_name
],
dequant_states
):
str2statenames
[
optim_name
],
dequant_states
):
s1cpy
=
s
.
clone
()
s1cpy
=
s
.
clone
()
raws1cpy
=
bnb_optimizer
.
state
[
p2
][
name2
].
clone
()
raws1cpy
=
bnb_optimizer
.
state
[
p2
][
name2
].
clone
()
qmap1
=
bnb_optimizer
.
state
[
p2
][
qmap
].
clone
()
qmap1
=
bnb_optimizer
.
state
[
p2
][
qmap
].
clone
()
...
@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
)
)
torch
.
testing
.
assert_close
(
s1cpy
,
s1
)
torch
.
testing
.
assert_close
(
s1cpy
,
s1
)
num_not_close
=
(
torch
.
isclose
(
torch_optimizer
.
state
[
p1
][
name1
],
s1
,
atol
=
atol
,
rtol
=
rtol
)
==
0
)
num_not_close
=
torch
.
isclose
(
torch_optimizer
.
state
[
p1
][
name1
],
s1
,
atol
=
atol
,
rtol
=
rtol
)
==
0
assert
num_not_close
.
sum
().
item
()
<
20
assert
num_not_close
.
sum
().
item
()
<
20
# since Lion can have pretty noisy updates where things lie at the boundary
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
# allow up to 5 errors for Lion
...
@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
...
@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
for
i
in
range
(
50
):
for
i
in
range
(
50
):
step
+=
1
step
+=
1
g1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
+
(
g1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
+
(
0.01
*
i
)
0.01
*
i
)
g2
=
g1
.
clone
()
g2
=
g1
.
clone
()
p2
.
grad
=
g2
p2
.
grad
=
g2
current_gnorm
,
clip_val
,
gnorm_scale
=
F
.
percentile_clipping
(
current_gnorm
,
clip_val
,
gnorm_scale
=
F
.
percentile_clipping
(
g1
,
gnorm_vec
,
step
,
5
)
g1
,
gnorm_vec
,
step
,
5
)
g1
=
(
g1
.
float
()
*
gnorm_scale
).
to
(
gtype
)
g1
=
(
g1
.
float
()
*
gnorm_scale
).
to
(
gtype
)
p1
.
grad
=
g1
p1
.
grad
=
g1
...
@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
...
@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
2
*
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
2
*
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"optim_name"
,
[
'
paged_adamw
'
],
ids
=
id_formatter
(
"optim_name"
))
@
pytest
.
mark
.
parametrize
(
"optim_name"
,
[
"
paged_adamw
"
],
ids
=
id_formatter
(
"optim_name"
))
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
'
bnb
'
],
ids
=
id_formatter
(
"mode"
))
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
"
bnb
"
],
ids
=
id_formatter
(
"mode"
))
@
pytest
.
mark
.
benchmark
@
pytest
.
mark
.
benchmark
def
test_stream_optimizer_bench
(
dim1
,
gtype
,
optim_name
,
mode
):
def
test_stream_optimizer_bench
(
dim1
,
gtype
,
optim_name
,
mode
):
layers1
=
torch
.
nn
.
Sequential
(
*
torch
.
nn
.
ModuleList
([
torch
.
nn
.
Linear
(
dim1
,
dim1
)
for
i
in
range
(
10
)]))
layers1
=
torch
.
nn
.
Sequential
(
*
torch
.
nn
.
ModuleList
([
torch
.
nn
.
Linear
(
dim1
,
dim1
)
for
i
in
range
(
10
)]))
...
@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
...
@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1
=
layers1
.
cuda
()
layers1
=
layers1
.
cuda
()
large_tensor
=
None
large_tensor
=
None
if
mode
==
'
torch
'
:
if
mode
==
"
torch
"
:
optim
=
str2optimizers
[
optim_name
][
0
](
layers1
.
parameters
())
optim
=
str2optimizers
[
optim_name
][
0
](
layers1
.
parameters
())
else
:
else
:
optim
=
str2optimizers
[
optim_name
][
1
](
layers1
.
parameters
())
optim
=
str2optimizers
[
optim_name
][
1
](
layers1
.
parameters
())
# 12 GB
# 12 GB
large_tensor
=
torch
.
empty
((
int
(
4.5e9
),),
device
=
'
cuda
'
)
large_tensor
=
torch
.
empty
((
int
(
4.5e9
),),
device
=
"
cuda
"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
time
.
sleep
(
5
)
time
.
sleep
(
5
)
num_batches
=
5
num_batches
=
5
batches
=
torch
.
randn
(
num_batches
,
128
,
dim1
,
device
=
'
cuda
'
).
to
(
gtype
)
batches
=
torch
.
randn
(
num_batches
,
128
,
dim1
,
device
=
"
cuda
"
).
to
(
gtype
)
lbls
=
torch
.
randint
(
0
,
10
,
size
=
(
num_batches
,
128
)).
cuda
()
lbls
=
torch
.
randint
(
0
,
10
,
size
=
(
num_batches
,
128
)).
cuda
()
for
i
in
range
(
num_batches
):
for
i
in
range
(
num_batches
):
print
(
i
)
print
(
i
)
b
=
batches
[
i
]
b
=
batches
[
i
]
if
i
==
2
:
if
i
==
2
:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
...
...
tests/test_triton.py
View file @
5a4263f4
...
@@ -7,15 +7,18 @@ from bitsandbytes.triton.triton_utils import is_triton_available
...
@@ -7,15 +7,18 @@ from bitsandbytes.triton.triton_utils import is_triton_available
from
tests.helpers
import
TRUE_FALSE
from
tests.helpers
import
TRUE_FALSE
@
pytest
.
mark
.
skipif
(
not
is_triton_available
()
or
not
torch
.
cuda
.
is_available
()
or
not
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
,
@
pytest
.
mark
.
skipif
(
reason
=
"This test requires triton and a GPU with compute capability 8.0 or higher."
)
not
is_triton_available
()
or
not
torch
.
cuda
.
is_available
()
or
not
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
,
reason
=
"This test requires triton and a GPU with compute capability 8.0 or higher."
,
)
@
pytest
.
mark
.
parametrize
(
"vector_wise_quantization"
,
TRUE_FALSE
)
@
pytest
.
mark
.
parametrize
(
"vector_wise_quantization"
,
TRUE_FALSE
)
def
test_switchback
(
vector_wise_quantization
):
def
test_switchback
(
vector_wise_quantization
):
for
dim
in
[
83
]:
for
dim
in
[
83
]:
for
batch
in
[
13
]:
for
batch
in
[
13
]:
standard
=
torch
.
nn
.
Linear
(
dim
,
4
*
dim
).
cuda
().
half
()
standard
=
torch
.
nn
.
Linear
(
dim
,
4
*
dim
).
cuda
().
half
()
switchback
=
SwitchBackLinear
(
dim
,
4
*
dim
,
vector_wise_quantization
=
vector_wise_quantization
).
cuda
().
half
()
switchback
=
(
SwitchBackLinear
(
dim
,
4
*
dim
,
vector_wise_quantization
=
vector_wise_quantization
).
cuda
().
half
()
)
baseline
=
Linear8bitLt
(
dim
,
4
*
dim
).
cuda
().
half
()
baseline
=
Linear8bitLt
(
dim
,
4
*
dim
).
cuda
().
half
()
switchback
.
weight
.
data
.
copy_
(
standard
.
weight
)
switchback
.
weight
.
data
.
copy_
(
standard
.
weight
)
switchback
.
bias
.
data
.
copy_
(
standard
.
bias
)
switchback
.
bias
.
data
.
copy_
(
standard
.
bias
)
...
@@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization):
...
@@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization):
err_sb
=
(
out_standard
-
out_sb
).
abs
().
mean
()
err_sb
=
(
out_standard
-
out_sb
).
abs
().
mean
()
err_baseline
=
(
out_standard
-
out_baseline
).
abs
().
mean
()
err_baseline
=
(
out_standard
-
out_baseline
).
abs
().
mean
()
print
(
'
OUT
'
,
err_sb
,
err_baseline
)
print
(
"
OUT
"
,
err_sb
,
err_baseline
)
assert
err_sb
<
2
*
err_baseline
assert
err_sb
<
2
*
err_baseline
err_sb
=
(
standard
.
bias
.
grad
-
switchback
.
bias
.
grad
).
abs
().
mean
()
err_sb
=
(
standard
.
bias
.
grad
-
switchback
.
bias
.
grad
).
abs
().
mean
()
err_baseline
=
(
standard
.
bias
.
grad
-
baseline
.
bias
.
grad
).
abs
().
mean
()
err_baseline
=
(
standard
.
bias
.
grad
-
baseline
.
bias
.
grad
).
abs
().
mean
()
print
(
'
GW2
'
,
err_sb
,
err_baseline
)
print
(
"
GW2
"
,
err_sb
,
err_baseline
)
assert
err_sb
<
2
*
err_baseline
assert
err_sb
<
2
*
err_baseline
err_sb
=
(
standard
.
weight
.
grad
-
switchback
.
weight
.
grad
).
abs
().
mean
()
err_sb
=
(
standard
.
weight
.
grad
-
switchback
.
weight
.
grad
).
abs
().
mean
()
err_baseline
=
(
standard
.
weight
.
grad
-
baseline
.
weight
.
grad
).
abs
().
mean
()
err_baseline
=
(
standard
.
weight
.
grad
-
baseline
.
weight
.
grad
).
abs
().
mean
()
print
(
'
GW1
'
,
err_sb
,
err_baseline
)
print
(
"
GW1
"
,
err_sb
,
err_baseline
)
assert
err_sb
<
2
*
err_baseline
assert
err_sb
<
2
*
err_baseline
err_sb
=
(
x1
.
grad
-
x2
.
grad
).
abs
().
mean
()
err_sb
=
(
x1
.
grad
-
x2
.
grad
).
abs
().
mean
()
err_baseline
=
(
x1
.
grad
-
x3
.
grad
).
abs
().
mean
()
err_baseline
=
(
x1
.
grad
-
x3
.
grad
).
abs
().
mean
()
print
(
'
GX1
'
,
err_sb
,
err_baseline
)
print
(
"
GX1
"
,
err_sb
,
err_baseline
)
assert
err_sb
<
2
*
err_baseline
assert
err_sb
<
2
*
err_baseline
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment