OpenDAS / bitsandbytes / Commits

Commit 06029dd6 (Unverified)
Authored Mar 13, 2024 by Titus, committed by GitHub on Mar 13, 2024

    Merge pull request #1081 from akx/ruff-format

    Reformat Python code with Ruff

Parents: fd723b78, 5a4263f4
Changes: 41

Showing 20 changed files with 928 additions and 916 deletions (+928 -916)
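The hunks below are almost entirely mechanical style changes from the Ruff formatter: single-quoted strings become double-quoted, bare float literals such as 1. become 1.0, long argument lists are exploded one argument per line with a trailing comma, and the lint configuration in pyproject.toml moves under [tool.ruff.lint]. A minimal before/after sketch of the pattern (illustrative only; the function below is invented for this example and does not appear in the commit):

# before ruff format
def scale(x, factor=1./127): return [{'value': v * factor} for v in x]

# after ruff format
def scale(x, factor=1.0 / 127):
    return [{"value": v * factor} for v in x]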
bitsandbytes/triton/dequantize_rowwise.py                  +19   -19
bitsandbytes/triton/int8_matmul_mixed_dequantize.py        +94   -54
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py      +94   -53
bitsandbytes/triton/quantize_columnwise_and_transpose.py   +25   -23
bitsandbytes/triton/quantize_global.py                     +50   -31
bitsandbytes/triton/quantize_rowwise.py                    +19   -18
bitsandbytes/utils.py                                      +13   -15
check_bnb_install.py                                       +5    -5
examples/int8_inference_huggingface.py                     +4    -9
install_cuda.py                                            +12   -4
pyproject.toml                                             +6    -3
scripts/stale.py                                           +2    -1
tests/test_autograd.py                                     +61   -79
tests/test_cuda_setup_evaluator.py                         +1    -4
tests/test_functional.py                                   +342  -435
tests/test_generation.py                                   +39   -36
tests/test_linear4bit.py                                   +5    -7
tests/test_linear8bitlt.py                                 +15   -3
tests/test_modules.py                                      +73   -75
tests/test_optim.py                                        +49   -42
bitsandbytes/triton/dequantize_rowwise.py  View file @ 06029dd6

@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl

@@ -15,21 +16,21 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
-       key=['n_elements']
+       key=["n_elements"],
    )
    @triton.jit
    def _dequantize_rowwise(

@@ -51,7 +52,6 @@ else:
        output = max_val * x * inv_127
        tl.store(output_ptr + offsets, output, mask=row_mask)

    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
        output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

@@ -60,5 +60,5 @@ else:
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (x.shape[0],)
-       _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
+       _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
        return output
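For orientation, _dequantize_rowwise multiplies each stored int8 row by its saved row maximum and by 1/127 (the `max_val * x * inv_127` line above). A rough pure-PyTorch equivalent (an illustrative sketch, not code from this repository; it assumes `state_x` holds one absolute maximum per row):

import torch

def dequantize_rowwise_reference(x_int8: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # x_int8: (rows, cols) int8 values in [-127, 127]; state_x: (rows,) per-row absolute maxima
    return state_x.unsqueeze(1) * x_int8.to(torch.float16) * (1.0 / 127)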
bitsandbytes/triton/int8_matmul_mixed_dequantize.py  View file @ 06029dd6

@@ -3,14 +3,14 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # This is a matmul kernel based on triton.ops.matmul
    # It is modified to support rowwise quantized input and global quantized weight
    # It's purpose is fused matmul then dequantize

@@ -27,58 +27,83 @@ else:
        for block_n in [32, 64, 128, 256]:
            num_warps = 2 if block_n <= 64 else 4
-           configs.append(
-               triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                             num_stages=num_stages, num_warps=num_warps))
+           configs.append(
+               triton.Config(
+                   {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                   num_stages=num_stages,
+                   num_warps=num_warps,
+               ),
+           )
            # split_k
            for split_k in [2, 4, 8, 16]:
-               configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                            num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+               configs.append(
+                   triton.Config(
+                       {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                       num_stages=num_stages,
+                       num_warps=num_warps,
+                       pre_hook=init_to_zero("C"),
+                   ),
+               )
        return configs

    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            # good for int8
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            *get_configs_io_bound(),
        ],
-       key=['M', 'N', 'K'],
-       prune_configs_by={
-           'early_config_prune': early_config_prune,
-           'perf_model': estimate_matmul_time,
-           'top_k': 10
-       },
-   )
-   @triton.heuristics({
-       'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-   })
+       key=["M", "N", "K"],
+       prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+   )
+   @triton.heuristics(
+       {
+           "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+       },
+   )
    @triton.jit
-   def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias: tl.constexpr,
-                   stride_am, stride_ak,
-                   stride_bk, stride_bn,
-                   stride_cm, stride_cn,
-                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                   GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                   ACC_TYPE: tl.constexpr
-                   ):
+   def _int8_matmul_mixed_dequantize(
+       A,
+       B,
+       C,
+       bias,
+       state_x_ptr,
+       state_w_ptr,
+       M,
+       N,
+       K,
+       divfactor: tl.constexpr,
+       has_bias: tl.constexpr,
+       stride_am,
+       stride_ak,
+       stride_bk,
+       stride_bn,
+       stride_cm,
+       stride_cn,
+       BLOCK_M: tl.constexpr,
+       BLOCK_N: tl.constexpr,
+       BLOCK_K: tl.constexpr,
+       GROUP_M: tl.constexpr,
+       SPLIT_K: tl.constexpr,
+       EVEN_K: tl.constexpr,
+       ACC_TYPE: tl.constexpr,
+   ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)

@@ -115,13 +140,13 @@ else:
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
-           a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-           b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+           a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+           b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
-       acc = (w_factor * (x_factor * (acc * divfactor)))
+       acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)
        # conditionally add bias

@@ -137,10 +162,9 @@ else:
        else:
            tl.atomic_add(C, acc, mask=mask)

    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
        device = a.device
-       divfactor = 1. / (127. * 127.)
+       divfactor = 1.0 / (127.0 * 127.0)
        has_bias = 0 if bias is None else 1
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:

@@ -154,12 +178,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_mixed_dequantize kernel
-       grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-       _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                   a.stride(0), a.stride(1),
-                   b.stride(0), b.stride(1),
-                   c.stride(0), c.stride(1),
-                   GROUP_M=8, ACC_TYPE=ACC_TYPE)
+       grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+       _int8_matmul_mixed_dequantize[grid](
+           a,
+           b,
+           c,
+           bias,
+           state_x,
+           state_w,
+           M,
+           N,
+           K,
+           divfactor,
+           has_bias,
+           a.stride(0),
+           a.stride(1),
+           b.stride(0),
+           b.stride(1),
+           c.stride(0),
+           c.stride(1),
+           GROUP_M=8,
+           ACC_TYPE=ACC_TYPE,
+       )
        return c
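The wrapper above launches a kernel that fuses the int8 matmul with dequantization: the accumulator is scaled by `divfactor = 1.0 / (127.0 * 127.0)` and by the saved activation and weight scales, cast to float16, and optionally biased. A rough plain-PyTorch reference of that math (an illustrative sketch with assumed tensor shapes, not the Triton kernel itself):

import torch

def int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, bias=None):
    # a: (M, K) int8 rowwise-quantized activations, state_x: (M,) per-row absmax
    # b: (K, N) int8 globally quantized weights, state_w: scalar absmax
    divfactor = 1.0 / (127.0 * 127.0)
    acc = a.float() @ b.float()                                  # what tl.dot accumulates
    out = state_w * (state_x.unsqueeze(1) * (acc * divfactor))   # re-apply the saved scales
    out = out.to(torch.float16)
    return out if bias is None else out + bias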
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py  View file @ 06029dd6

@@ -3,7 +3,9 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
else:
    import triton
    import triton.language as tl

@@ -17,7 +19,6 @@ else:
    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    def get_configs_io_bound():
        configs = []
        for num_stages in [2, 3, 4, 5, 6]:

@@ -26,58 +27,83 @@ else:
        for block_n in [32, 64, 128, 256]:
            num_warps = 2 if block_n <= 64 else 4
-           configs.append(
-               triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                             num_stages=num_stages, num_warps=num_warps))
+           configs.append(
+               triton.Config(
+                   {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                   num_stages=num_stages,
+                   num_warps=num_warps,
+               ),
+           )
            # split_k
            for split_k in [2, 4, 8, 16]:
-               configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                            num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+               configs.append(
+                   triton.Config(
+                       {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                       num_stages=num_stages,
+                       num_warps=num_warps,
+                       pre_hook=init_to_zero("C"),
+                   ),
+               )
        return configs

    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            # good for int8
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            *get_configs_io_bound(),
        ],
-       key=['M', 'N', 'K'],
-       prune_configs_by={
-           'early_config_prune': early_config_prune,
-           'perf_model': estimate_matmul_time,
-           'top_k': 10
-       },
-   )
-   @triton.heuristics({
-       'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-   })
+       key=["M", "N", "K"],
+       prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+   )
+   @triton.heuristics(
+       {
+           "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+       },
+   )
    @triton.jit
-   def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias: tl.constexpr,
-                   stride_am, stride_ak,
-                   stride_bk, stride_bn,
-                   stride_cm, stride_cn,
-                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                   GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                   ACC_TYPE: tl.constexpr
-                   ):
+   def _int8_matmul_rowwise_dequantize(
+       A,
+       B,
+       C,
+       bias,
+       state_x_ptr,
+       state_w_ptr,
+       M,
+       N,
+       K,
+       divfactor,
+       has_bias: tl.constexpr,
+       stride_am,
+       stride_ak,
+       stride_bk,
+       stride_bn,
+       stride_cm,
+       stride_cn,
+       BLOCK_M: tl.constexpr,
+       BLOCK_N: tl.constexpr,
+       BLOCK_K: tl.constexpr,
+       GROUP_M: tl.constexpr,
+       SPLIT_K: tl.constexpr,
+       EVEN_K: tl.constexpr,
+       ACC_TYPE: tl.constexpr,
+   ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)

@@ -114,13 +140,13 @@ else:
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
-           a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-           b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+           a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+           b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
-       acc = (w_factor * (x_factor * (acc * divfactor)))
+       acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)
        if has_bias:

@@ -135,9 +161,8 @@ else:
        else:
            tl.atomic_add(C, acc, mask=mask)

    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
-       divfactor = 1. / (127. * 127.)
+       divfactor = 1.0 / (127.0 * 127.0)
        has_bias = 0 if bias is None else 1

@@ -154,12 +179,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_rowwise_dequantize kernel
-       grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-       _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                   a.stride(0), a.stride(1),
-                   b.stride(0), b.stride(1),
-                   c.stride(0), c.stride(1),
-                   GROUP_M=8, ACC_TYPE=ACC_TYPE)
+       grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+       _int8_matmul_rowwise_dequantize[grid](
+           a,
+           b,
+           c,
+           bias,
+           state_x,
+           state_w,
+           M,
+           N,
+           K,
+           divfactor,
+           has_bias,
+           a.stride(0),
+           a.stride(1),
+           b.stride(0),
+           b.stride(1),
+           c.stride(0),
+           c.stride(1),
+           GROUP_M=8,
+           ACC_TYPE=ACC_TYPE,
+       )
        return c
bitsandbytes/triton/quantize_columnwise_and_transpose.py  View file @ 06029dd6

@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl

@@ -15,23 +16,23 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_stages=16),
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=16, num_warps=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
-       key=['n_elements']
+       key=["n_elements"],
    )
    @triton.jit
    def _quantize_columnwise_and_transpose(

@@ -39,7 +40,8 @@ else:
        output_ptr,
        output_maxs,
        n_elements,
-       M : tl.constexpr, N : tl.constexpr,
+       M: tl.constexpr,
+       N: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):

@@ -47,12 +49,12 @@ else:
        block_start = pid
        p2_arange = tl.arange(0, P2)
        p2_arange_mask = p2_arange < M
        arange = p2_arange * N
        offsets = block_start + arange
        x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
-       output = tl.libdevice.llrint(127. * (x / max_val))
+       output = tl.libdevice.llrint(127.0 * (x / max_val))
        new_start = pid * M
        new_offsets = new_start + p2_arange

@@ -68,6 +70,6 @@ else:
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
-       grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+       grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
        return output, output_maxs
bitsandbytes/triton/quantize_global.py  View file @ 06029dd6

import torch

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_global_transpose(input): return None
    def quantize_global(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl

    # global quantize
    @triton.autotune(
        configs=[
-           triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
-           triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
+           triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+           triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
        ],
-       key=['n_elements']
+       key=["n_elements"],
    )
    @triton.jit
    def _quantize_global(

@@ -34,35 +35,43 @@ else:
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)
-       output = tl.libdevice.llrint(127. * (x * absmax_inv))
+       output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
        tl.store(output_ptr + offsets, output, mask=mask)

    def quantize_global(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
-       absmax_inv = 1. / absmax
-       output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
+       absmax_inv = 1.0 / absmax
+       output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
-       grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+       grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _quantize_global[grid](x, absmax_inv, output, n_elements)
        return output, absmax

    # global quantize and transpose
    @triton.autotune(
        configs=[
-           triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
-           triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
+           triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
+           triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
            # ...
        ],
-       key=['M', 'N']
+       key=["M", "N"],
    )
    @triton.jit
-   def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
-                                  BLOCK_M: tl.constexpr,
-                                  BLOCK_N: tl.constexpr,
-                                  GROUP_M: tl.constexpr):
+   def _quantize_global_transpose(
+       A,
+       absmax_inv_ptr,
+       B,
+       stride_am,
+       stride_an,
+       stride_bn,
+       stride_bm,
+       M,
+       N,
+       BLOCK_M: tl.constexpr,
+       BLOCK_N: tl.constexpr,
+       GROUP_M: tl.constexpr,
+   ):
        pid = tl.program_id(0)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N

@@ -86,20 +95,30 @@ else:
        B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
-       output = tl.libdevice.llrint(127. * (a * absmax_inv))
+       output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
        tl.store(B, output, mask=mask)

    def quantize_global_transpose(input):
        absmax = input.abs().max().unsqueeze(0)
-       absmax_inv = 1. / absmax
+       absmax_inv = 1.0 / absmax
        M, N = input.shape
-       out = torch.empty(N, M, device='cuda', dtype=torch.int8)
+       out = torch.empty(N, M, device="cuda", dtype=torch.int8)
        assert out.size(0) == N and out.size(1) == M
        assert input.stride(0) == 1 or input.stride(1) == 1
        assert out.stride(0) == 1 or out.stride(1) == 1
-       grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
-       _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
+       grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+       _quantize_global_transpose[grid](
+           input,
+           absmax_inv,
+           out,
+           input.stride(0),
+           input.stride(1),
+           out.stride(0),
+           out.stride(1),
+           M,
+           N,
+       )
        return out, absmax
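quantize_global uses a single scale for the whole tensor: it takes `absmax = x.abs().max()`, multiplies by `1.0 / absmax`, and rounds `127 * x * absmax_inv` to int8. A compact PyTorch sketch of the same math (illustrative only; it ignores the Triton launch and the transpose path):

import torch

def quantize_global_reference(x: torch.Tensor):
    absmax = x.abs().max().unsqueeze(0)
    output = torch.round(127.0 * (x * (1.0 / absmax))).to(torch.int8)
    return output, absmax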
bitsandbytes/triton/quantize_rowwise.py  View file @ 06029dd6

@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_rowwise(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl

@@ -15,21 +16,21 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
-       key=['n_elements']
+       key=["n_elements"],
    )
    @triton.jit
    def _quantize_rowwise(

@@ -49,7 +50,7 @@ else:
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
-       output = tl.libdevice.llrint(127. * (x / max_val))
+       output = tl.libdevice.llrint(127.0 * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        tl.store(output_maxs + pid, max_val)
bitsandbytes/utils.py  View file @ 06029dd6

@@ -30,7 +30,7 @@ def outlier_hook(module, input):
    # (1) zscore test of std of hidden dimension
    outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
    # (2) magnitude > 6 test
    dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
    outlier_idx2 = torch.where(dims > 0)[0]
    outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
    tracer.hvalue2outlier_idx[hvalue] = outlier_idx

@@ -59,14 +59,14 @@ class OutlierTracer:
            self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
-       return getattr(self, 'initialized', False)
+       return getattr(self, "initialized", False)

    def get_hvalue(self, weight):
        return weight.data.storage().data_ptr()

    def get_outliers(self, weight):
        if not self.is_initialized():
-           print('Outlier tracer is not initialized...')
+           print("Outlier tracer is not initialized...")
            return None
        hvalue = self.get_hvalue(weight)
        if hvalue in self.hvalue2outlier_idx:

@@ -80,6 +80,7 @@ class OutlierTracer:
            cls._instance = cls.__new__(cls)
        return cls._instance

def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    if rdm:
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()

@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
    m = weight.mean(reduction_dim)
    mm = m.mean()
    mstd = m.std()
    zm = (m - mm) / mstd

    std = weight.std(reduction_dim)
    stdm = std.mean()
    stdstd = std.std()
    zstd = (std - stdm) / stdstd

    if topk is not None:
        val, idx = torch.topk(std.abs(), k=topk, dim=0)

@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
def execute_and_return(command_string: str) -> Tuple[str, str]:
    def _decode(subprocess_err_out_tuple):
-       return tuple(
-           to_decode.decode("UTF-8").strip()
-           for to_decode in subprocess_err_out_tuple
-       )
+       return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple)

    def execute_and_return_decoded_std_streams(command_string):
        return _decode(

@@ -116,14 +114,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
                shlex.split(command_string),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
-           ).communicate()
+           ).communicate(),
        )

    std_out, std_err = execute_and_return_decoded_std_streams(command_string)
    return std_out, std_err

def replace_linear(
    model,
    linear_replacement,

@@ -163,8 +160,9 @@ def replace_linear(
            model._modules[name].bias = old_module.bias

        if post_processing_function is not None:
            func = getattr(module, post_processing_function, None)
-           if func is not None: func(module)
+           if func is not None:
+               func(module)
    return model

@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
        A torch tensor containing the packed data.
    """
    json_str = json.dumps(source_dict)
-   json_bytes = json_str.encode('utf-8')
+   json_bytes = json_str.encode("utf-8")
    tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8)

    return tensor_data

@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
        A Python dictionary containing the unpacked data.
    """
    json_bytes = bytes(tensor_data.cpu().numpy())
-   json_str = json_bytes.decode('utf-8')
+   json_str = json_bytes.decode("utf-8")
    unpacked_dict = json.loads(json_str)

    return unpacked_dict
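pack_dict_to_tensor and unpack_tensor_to_dict round-trip a Python dict through JSON into a torch.uint8 tensor, which lets metadata travel inside a state dict. A short usage sketch (the dict contents here are made-up example values):

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

meta = {"blocksize": 64, "quant_type": "nf4"}
packed = pack_dict_to_tensor(meta)        # torch.uint8 tensor holding the JSON bytes
assert unpack_tensor_to_dict(packed) == meta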
check_bnb_install.py  View file @ 06029dd6

@@ -2,14 +2,14 @@ import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

p1 = p.data.sum().item()

adam = bnb.optim.Adam([p])

out = a * p
loss = out.sum()
loss.backward()
adam.step()

@@ -17,5 +17,5 @@ adam.step()
p2 = p.data.sum().item()

assert p1 != p2
-print('SUCCESS!')
-print('Installation was successful!')
+print("SUCCESS!")
+print("Installation was successful!")
examples/int8_inference_huggingface.py  View file @ 06029dd6

@@ -2,23 +2,18 @@ import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

MAX_NEW_TOKENS = 128
-model_name = 'meta-llama/Llama-2-7b-hf'
+model_name = "meta-llama/Llama-2-7b-hf"

-text = 'Hamburg is in which country?\n'
+text = "Hamburg is in which country?\n"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

-max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
+max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

-model = LlamaForCausalLM.from_pretrained(
-    model_name,
-    device_map='auto',
-    load_in_8bit=True,
-    max_memory=max_memory
-)
+model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
install_cuda.py  View file @ 06029dd6

@@ -19,6 +19,7 @@ cuda_versions = {
    "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
}

def install_cuda(version, base_path, download_path):
    formatted_version = f"{version[:-1]}.{version[-1]}"
    folder = f"cuda-{formatted_version}"

@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
        subprocess.run(["rm", "-rf", install_path], check=True)

    url = cuda_versions[version]
-   filename = url.split('/')[-1]
+   filename = url.split("/")[-1]
    filepath = os.path.join(download_path, filename)

    if not os.path.exists(filepath):

@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
    # Install CUDA
    print(f"Installing CUDA version {version}...")
    install_command = [
-       "bash", filepath,
-       "--no-drm", "--no-man-page", "--override",
-       "--toolkitpath=" + install_path,
-       "--toolkit", "--silent"
+       "bash",
+       filepath,
+       "--no-drm",
+       "--no-man-page",
+       "--override",
+       "--toolkitpath=" + install_path,
+       "--toolkit",
+       "--silent",
    ]

    print(f"Running command: {' '.join(install_command)}")

@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
    print(f"CUDA version {version} installed at {install_path}")

def main():
    user_base_path = os.path.expanduser("~/cuda")
    system_base_path = "/usr/local/cuda"

@@ -93,5 +100,6 @@ def main():
        print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
        sys.exit(1)

if __name__ == "__main__":
    main()
pyproject.toml  View file @ 06029dd6

@@ -8,6 +8,10 @@ src = [
    "tests",
    "benchmarking"
]
+target-version = "py38"
+line-length = 119
+
+[tool.ruff.lint]
select = [
    "B",  # bugbear: security warnings
    "E",  # pycodestyle

@@ -17,7 +21,6 @@ select = [
    "UP",  # alert you when better syntax is available in your python version
    "RUF",  # the ruff developer's own rules
]
-target-version = "py38"
ignore = [
    "B007",  # Loop control variable not used within the loop body (TODO: enable)
    "B028",  # Warning without stacklevel (TODO: enable)

@@ -30,7 +33,7 @@ ignore = [
]
ignore-init-module-imports = true  # allow to expose in __init__.py via imports

-[tool.ruff.extend-per-file-ignores]
+[tool.ruff.lint.extend-per-file-ignores]
"**/__init__.py" = ["F401"]  # allow unused imports in __init__.py
"{benchmarking,tests}/**/*.py" = [
    "B007",

@@ -42,7 +45,7 @@ ignore-init-module-imports = true  # allow to expose in __init__.py via imports
    "UP030",
]

-[tool.ruff.isort]
+[tool.ruff.lint.isort]
combine-as-imports = true
detect-same-package = true
force-sort-within-sections = true
scripts/stale.py  View file @ 06029dd6

@@ -15,6 +15,7 @@
Script to close stale issue. Taken in part from the AllenNLP repository.
https://github.com/allenai/allennlp.
"""

from datetime import datetime as dt, timezone
import os

@@ -50,7 +51,7 @@ def main():
            issue.create_comment(
                "This issue has been automatically marked as stale because it has not had "
                "recent activity. If you think this still needs to be addressed "
-               "please comment on this thread.\n\n"
+               "please comment on this thread.\n\n",
            )
tests/test_autograd.py
View file @ 06029dd6
...
@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
 @pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
 @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
+    ids=["func=bmm", "func=matmul"],
+)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
...
@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
     dim3 = dim3 - (dim3 % 16)
     dim4 = dim4 - (dim4 % 16)
     for i in range(25):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
             dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
             dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
             A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
             B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
-            target = torch.randn(
-                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
-            )
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
             torch.nn.init.xavier_uniform_(B)
             if not transpose[0] and not transpose[1]:
...
@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
             A.grad = None
             B.grad = None
             loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
...
@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
             B.grad = None
             if req_grad[0]:
                 torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
                 torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

         # batched matrix multiply
         if funcs[0] in [torch.bmm, torch.matmul]:
...
@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
             assert (idx == 0).sum().item() < n * 0.01
             torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
...
@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 A.grad = None
                 B.grad = None
                 loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
...
@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 B.grad = None
                 if req_grad[0]:
                     torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
                 if req_grad[1]:
                     n = gradB1.numel()
                     idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...
@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 A.grad = None
                 B.grad = None
                 loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
...
@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 B.grad = None
                 if req_grad[0]:
                     torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
                 if req_grad[1]:
                     n = gradB1.numel()
                     idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...
@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
 @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
 @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
-@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
+    ids=["func=matmul", "func=switchback_bnb"],
+)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
 @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
 def test_matmullt(
     dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
...
@@ -245,18 +222,13 @@ def test_matmullt(
     req_grad[2] = False

     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
             A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
             B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
                 device="cuda",
...
@@ -266,7 +238,7 @@ def test_matmullt(
         bias = None
         bias2 = None
         if has_bias:
-            bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+            bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
             bias2 = bias.clone()
         torch.nn.init.xavier_uniform_(B)
         B2 = B.clone()
...
@@ -311,9 +283,7 @@ def test_matmullt(
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
             loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
             loss_bnb.backward()
             gradA1 = A.grad
             gradB1 = B.grad
...
@@ -323,9 +293,7 @@ def test_matmullt(
                 gradBias1 = bias.grad
                 bias.grad = None
             loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
...
@@ -336,9 +304,7 @@ def test_matmullt(
                 bias.grad = None
             if req_grad[0]:
                 torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[1]:
                 n = gradB1.numel()
                 if dim2 > 0:
...
@@ -352,9 +318,7 @@ def test_matmullt(
                     assert (idx == 0).sum().item() <= n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() <= n * 0.02
                 torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

             if req_grad[2]:
                 torch.testing.assert_close(gradBias1, gradBias2)
...
@@ -370,8 +334,20 @@ def test_matmullt(
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
-def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
+@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
+def test_matmul_4bit(
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    has_bias,
+    compress_statistics,
+    quant_type,
+):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
...
@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
     bias = None
     bias2 = None
     if has_bias:
-        bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+        bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
         bias2 = bias.clone()
     torch.nn.init.xavier_uniform_(B)
-    B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
+    B2, quant_state = bnb.functional.quantize_4bit(
+        B,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+    )
     if not transpose[0] and transpose[1]:
         out_torch = funcs[0](A, B.t())
...
@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         if n > 0:
             assert err < 0.115
-            #assert err < 0.20
+            # assert err < 0.20
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
...
@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                 gradBias1 = bias.grad
                 bias.grad = None
             loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
...
@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                 bias.grad = None
             if req_grad[0]:
                 torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[2]:
                 torch.testing.assert_close(gradBias1, gradBias2)
...
@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
 @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
-def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
+    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
+)
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     req_grad = list(req_grad)
...
@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     err = torch.abs(out_bnb - out_torch).float().mean().item()
     if n > 0:
         assert err < 0.115
-        #assert err < 0.20
+        # assert err < 0.20
     if any(req_grad):
         out_bnb.data.copy_(out_torch)
         torch.cuda.synchronize()
...
@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         A.grad = None
         B.grad = None
         loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
         loss_torch.backward()
         gradA2 = A.grad
         gradB2 = B.grad
...
@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         B.grad = None
         if req_grad[0]:
             torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
         if req_grad[1]:
             n = gradB1.numel()
...
@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             assert (idx == 0).sum().item() <= n * 0.1
             idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
             assert (idx == 0).sum().item() <= n * 0.02
             grad_err = (gradB1 - gradB2).abs().mean()
             assert grad_err.item() < 0.003
             torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
tests/test_cuda_setup_evaluator.py
View file @ 06029dd6
...
@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
 def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert (
-        get_cuda_bnb_library_path(cuda111_noblas_spec).stem
-        == "libbitsandbytes_cuda111_nocublaslt"
-    )
+    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
tests/test_functional.py
View file @
06029dd6
...
@@ -19,9 +19,7 @@ from tests.helpers import (
...
@@ -19,9 +19,7 @@ from tests.helpers import (
id_formatter
,
id_formatter
,
)
)
torch
.
set_printoptions
(
torch
.
set_printoptions
(
precision
=
5
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
)
precision
=
5
,
sci_mode
=
False
,
linewidth
=
120
,
edgeitems
=
20
,
threshold
=
10000
)
k
=
20
k
=
20
...
@@ -98,9 +96,7 @@ def teardown():
...
@@ -98,9 +96,7 @@ def teardown():
pass
pass
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_estimate_quantiles
(
dtype
):
def
test_estimate_quantiles
(
dtype
):
A
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A
=
A
.
to
(
dtype
)
A
=
A
.
to
(
dtype
)
...
@@ -136,7 +132,6 @@ def test_quantile_quantization():
...
@@ -136,7 +132,6 @@ def test_quantile_quantization():
assert
diff
<
0.001
assert
diff
<
0.001
def
test_dynamic_quantization
():
def
test_dynamic_quantization
():
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
...
@@ -149,8 +144,8 @@ def test_dynamic_quantization():
...
@@ -149,8 +144,8 @@ def test_dynamic_quantization():
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
assert
diff
.
mean
().
item
()
<
0.0135
assert
diff
.
mean
().
item
()
<
0.0135
print
(
sum
(
diffs
)
/
len
(
diffs
))
print
(
sum
(
diffs
)
/
len
(
diffs
))
print
(
sum
(
reldiffs
)
/
len
(
reldiffs
))
print
(
sum
(
reldiffs
)
/
len
(
reldiffs
))
for
i
in
range
(
100
):
for
i
in
range
(
100
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
...
@@ -161,13 +156,12 @@ def test_dynamic_quantization():
...
@@ -161,13 +156,12 @@ def test_dynamic_quantization():
assert
diff
<
0.004
assert
diff
<
0.004
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"nested"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"nested"
))
@
pytest
.
mark
.
parametrize
(
"nested"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"nested"
))
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
parametrize
(
"signed"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"signed"
))
@
pytest
.
mark
.
parametrize
(
"signed"
,
TRUE_FALSE
,
ids
=
id_formatter
(
"signed"
))
def
test_dynamic_blockwise_quantization
(
dtype
,
nested
,
blocksize
,
signed
):
def
test_dynamic_blockwise_quantization
(
dtype
,
nested
,
blocksize
,
signed
):
#print('')
#
print('')
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
...
@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#
print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
#
print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert
abserr
<
0.011
assert
abserr
<
0.011
assert
relerr
<
0.018
assert
relerr
<
0.018
assert
A2
.
dtype
==
dtype
assert
A2
.
dtype
==
dtype
...
@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
reldiff
=
diff
/
torch
.
abs
(
A1
.
float
()
+
1e-8
)
diffs
.
append
(
diff
.
mean
().
item
())
diffs
.
append
(
diff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
reldiffs
.
append
(
reldiff
.
mean
().
item
())
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
#
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
abserr
=
sum
(
diffs
)
/
len
(
diffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
relerr
=
sum
(
reldiffs
)
/
len
(
reldiffs
)
if
signed
:
if
signed
:
assert
abserr
<
0.0035
assert
abserr
<
0.0035
assert
relerr
<
0.015
assert
relerr
<
0.015
...
@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
...
@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
assert
abserr
<
0.00175
assert
abserr
<
0.00175
assert
relerr
<
0.012
assert
relerr
<
0.012
assert
A2
.
dtype
==
dtype
assert
A2
.
dtype
==
dtype
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_percentile_clipping
(
gtype
):
def
test_percentile_clipping
(
gtype
):
gnorm_vec1
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec1
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec2
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec2
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
...
@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
...
@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
step
+=
1
step
+=
1
g
=
torch
.
randn
(
n
,
n
,
dtype
=
gtype
,
device
=
"cuda"
)
g
=
torch
.
randn
(
n
,
n
,
dtype
=
gtype
,
device
=
"cuda"
)
gnorm1
,
clip2
,
gnorm_scale
=
F
.
percentile_clipping
(
gnorm1
,
clip2
,
gnorm_scale
=
F
.
percentile_clipping
(
g
,
gnorm_vec2
,
step
,
percentile
=
percentile
)
g
,
gnorm_vec2
,
step
,
percentile
=
percentile
)
assert
gnorm_scale
==
1.0
if
gnorm1
<
clip2
else
clip2
/
gnorm1
assert
gnorm_scale
==
1.0
if
gnorm1
<
clip2
else
clip2
/
gnorm1
gnorm2
=
torch
.
norm
(
g
.
float
())
gnorm2
=
torch
.
norm
(
g
.
float
())
...
@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2
=
dim2
-
(
dim2
%
32
)
dim2
=
dim2
-
(
dim2
%
32
)
errors
=
[]
errors
=
[]
relerrors
=
[]
relerrors
=
[]
#print("")
#
print("")
for
i
in
range
(
5
):
for
i
in
range
(
5
):
if
batched
:
if
batched
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
32
,
dim1
,
dim2
//
32
),
device
=
"cuda"
)
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
32
,
dim1
,
dim2
//
32
),
device
=
"cuda"
)
...
@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim2
,
dim1
),
device
=
"cuda"
)
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim2
,
dim1
),
device
=
"cuda"
)
maxA
,
Ac
=
quant_methods
[
0
](
A
,
1
)
maxA
,
Ac
=
quant_methods
[
0
](
A
,
1
)
maxB
,
Bc
=
quant_methods
[
1
](
B
,
0
)
maxB
,
Bc
=
quant_methods
[
1
](
B
,
0
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
quant_methods
[
2
](
maxA
,
Ac
),
A
,
atol
=
0.025
,
rtol
=
0.05
)
quant_methods
[
2
](
maxA
,
Ac
),
A
,
atol
=
0.025
,
rtol
=
0.05
)
if
batched
:
if
batched
:
out2
=
torch
.
bmm
(
A
,
B
)
out2
=
torch
.
bmm
(
A
,
B
)
C
=
torch
.
bmm
(
Ac
.
float
(),
Bc
.
float
())
C
=
torch
.
bmm
(
Ac
.
float
(),
Bc
.
float
())
...
@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
...
@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr
=
err
/
torch
.
abs
(
out2
)
relerr
=
err
/
torch
.
abs
(
out2
)
errors
.
append
(
err
.
mean
().
item
())
errors
.
append
(
err
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
#print(mean(errors))
#
print(mean(errors))
#print(mean(relerrors))
#
print(mean(relerrors))
def
test_stable_embedding
():
def
test_stable_embedding
():
...
@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
...
@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
batch_dim
=
batch_dim
-
(
batch_dim
%
16
)
batch_dim
=
batch_dim
-
(
batch_dim
%
16
)
seq_dim
=
seq_dim
-
(
seq_dim
%
16
)
seq_dim
=
seq_dim
-
(
seq_dim
%
16
)
for
i
in
range
(
k
):
for
i
in
range
(
k
):
shapeA
=
(
shapeA
=
(
batch_dim
,
hidden_dim
)
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
(
batch_dim
,
hidden_dim
)
shapeB
=
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
)
shapeB
=
(
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
...
@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
...
@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
shapeA
=
(
batch_dim
,
seq_dim
,
hidden_dim
)
shapeA
=
(
batch_dim
,
seq_dim
,
hidden_dim
)
shapeB
=
(
shapeB
=
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
if
transpose
[
1
]
else
(
hidden_dim
,
32
*
random
.
randint
(
1
,
4
))
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeB
,
device
=
"cuda"
).
to
(
torch
.
int8
)
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
if
not
transpose
[
0
]
and
not
transpose
[
1
]:
...
@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
...
@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
hidden_dim
=
hidden_dim
-
(
hidden_dim
%
32
)
hidden_dim
=
hidden_dim
-
(
hidden_dim
%
32
)
batch_dim
=
batch_dim
-
(
batch_dim
%
2
)
batch_dim
=
batch_dim
-
(
batch_dim
%
2
)
for
i
in
range
(
25
):
for
i
in
range
(
25
):
A
=
torch
.
randint
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
).
to
(
torch
.
int8
)
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
).
to
(
torch
.
int8
)
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
).
to
(
torch
.
int8
)
out2
=
torch
.
einsum
(
"bsi, bso->io"
,
A
.
float
(),
B
.
float
())
out2
=
torch
.
einsum
(
"bsi, bso->io"
,
A
.
float
(),
B
.
float
())
iout
=
torch
.
empty
(
iout
=
torch
.
empty
(
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
out
=
F
.
igemm
(
A
,
B
,
out
=
iout
)
out
=
F
.
igemm
(
A
,
B
,
out
=
iout
)
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
)
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
)
...
@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
...
@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2
=
[]
errs2
=
[]
relerrs2
=
[]
relerrs2
=
[]
for
i
in
range
(
k
):
for
i
in
range
(
k
):
A
=
torch
.
normal
(
A
=
torch
.
normal
(
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
if
transpose
:
if
transpose
:
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
256
,
hidden_dim
),
device
=
"cuda"
)
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
256
,
hidden_dim
),
device
=
"cuda"
)
else
:
else
:
...
@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
...
@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
float
())
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
float
())
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
)
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
)
elif
transpose
[
0
]
and
transpose
[
1
]:
elif
transpose
[
0
]
and
transpose
[
1
]:
out2
=
torch
.
bmm
(
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
())
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
()
)
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
.
permute
([
0
,
2
,
1
]))
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
.
permute
([
0
,
2
,
1
]))
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
.
float
())
torch
.
testing
.
assert_close
(
out
.
float
(),
out2
.
float
())
...
@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
...
@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
qA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=
0
)
qA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=
0
)
A1
=
F
.
vectorwise_dequant
(
qA
,
SA
)
A1
=
F
.
vectorwise_dequant
(
qA
,
SA
)
n
=
A1
.
numel
()
n
=
A1
.
numel
()
assert_all_approx_close
(
A1
,
A
,
atol
=
0.01
,
rtol
=
0.1
,
count
=
int
(
n
*
0.002
))
assert_all_approx_close
(
A1
,
A
,
atol
=
0.01
,
rtol
=
0.1
,
count
=
int
(
n
*
0.002
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
2
,
256
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
2
,
256
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
...
@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
dtype
)
out
,
S
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
out
,
S
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
...
@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
if
dims
==
2
:
n
=
A
.
shape
[
0
]
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
n
=
A
.
shape
[
0
]
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
elif
dims
==
3
:
elif
dims
==
3
:
n
=
(
n
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
)
assert
out
.
numel
()
==
n
assert
out
.
numel
()
==
n
elif
orderOut
==
"col_turing"
:
elif
orderOut
==
"col_turing"
:
# 32 col 8 row tiles
# 32 col 8 row tiles
n
=
(
A
.
shape
[
0
]
+
(
8
-
A
.
shape
[
0
]
%
8
))
*
(
n
=
(
A
.
shape
[
0
]
+
(
8
-
A
.
shape
[
0
]
%
8
))
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
))
)
assert
out
.
numel
()
==
n
assert
out
.
numel
()
==
n
total_coltile
=
(
A
.
shape
[
1
]
//
32
)
+
(
1
if
A
.
shape
[
1
]
%
32
!=
0
else
0
)
total_coltile
=
(
A
.
shape
[
1
]
//
32
)
+
(
1
if
A
.
shape
[
1
]
%
32
!=
0
else
0
)
for
row
in
range
(
A
.
shape
[
0
]):
for
row
in
range
(
A
.
shape
[
0
]):
...
@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j
=
col
j
=
col
coltile
=
(
col
//
32
)
+
(
1
if
col
%
32
!=
0
else
0
)
coltile
=
(
col
//
32
)
+
(
1
if
col
%
32
!=
0
else
0
)
rowtile
=
(
rowtile
=
((
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
))
*
total_coltile
(
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
)
)
*
total_coltile
offset
=
32
*
8
*
(
rowtile
+
coltile
)
offset
=
32
*
8
*
(
rowtile
+
coltile
)
col2
=
col
%
32
col2
=
col
%
32
row2
=
(
row
%
8
)
*
32
row2
=
(
row
%
8
)
*
32
...
@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if
orderOut
==
"col32"
:
if
orderOut
==
"col32"
:
out2
,
S
=
F
.
nvidia_transform
(
out2
,
S
=
F
.
nvidia_transform
(
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
torch
.
testing
.
assert_close
(
A
,
out2
)
torch
.
testing
.
assert_close
(
A
,
out2
)
...
@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
...
@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
def
test_igemmlt_int
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
):
def
test_igemmlt_int
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
torch
.
int8
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
A2
,
SA
=
F
.
transform
(
A
,
"col32"
)
A2
,
SA
=
F
.
transform
(
A
,
"col32"
)
...
@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
...
@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch
.
testing
.
assert_close
(
C1
,
C3
.
float
())
torch
.
testing
.
assert_close
(
C1
,
C3
.
float
())
# transpose
# transpose
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
torch
.
int8
)
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
float
())
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
float
())
B2t
,
SBt
=
F
.
transform
(
B
,
"col_turing"
,
transpose
=
True
)
B2t
,
SBt
=
F
.
transform
(
B
,
"col_turing"
,
transpose
=
True
)
...
@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
...
@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
half
()
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
half
()
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
normal
(
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
B
=
torch
.
randn
((
dim4
,
dim3
),
device
=
"cuda"
).
half
()
B
=
torch
.
randn
((
dim4
,
dim3
),
device
=
"cuda"
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
C1
=
torch
.
matmul
(
A
,
B
.
t
())
...
@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
...
@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# C3, S = F.transform(C2, 'row', state=SC)
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_close(C1, C3.float())
# torch.testing.assert_close(C1, C3.float())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
(
"batch"
,
"seq"
,
"model"
,
"hidden"
),
[
[
...
@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
...
@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
for
i
in
range
(
k
):
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
# fc1
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
# fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# out2 = torch.matmul(out1, w2.t())# fc2
...
@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
...
@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
):
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
):
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
bias
=
None
bias
=
None
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
1
):
for
i
in
range
(
1
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
if
has_bias
:
C1
+=
bias
if
has_bias
:
C1
+=
bias
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
...
@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
...
@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
if
has_bias
:
C4
+=
bias
if
has_bias
:
C4
+=
bias
# TODO: is something wrong here? If so, the problem goes deeper
# TODO: is something wrong here? If so, the problem goes deeper
# n = C1.numel()
# n = C1.numel()
...
@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
...
@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
else
:
else
:
assert
False
assert
False
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
threshold
)
A
,
threshold
=
threshold
)
A_blocked
=
einops
.
rearrange
(
A_blocked
=
einops
.
rearrange
(
torch
.
abs
(
A
),
torch
.
abs
(
A
),
...
@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
...
@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
torch
.
testing
.
assert_close
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_close
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_close
(
nnz_block_ptr1
.
int
(),
nnz_block_ptr2
)
torch
.
testing
.
assert_close
(
nnz_block_ptr1
.
int
(),
nnz_block_ptr2
)
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
0.0
)
A
,
threshold
=
0.0
)
torch
.
testing
.
assert_close
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_close
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_close
(
row_stats1
,
row_stats2
)
torch
.
testing
.
assert_close
(
row_stats1
,
row_stats2
)
...
@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
...
@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
torch
.
testing
.
assert_close
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
n
=
CAt
.
numel
()
n
=
CAt
.
numel
()
num_not_close_rows
=
(
num_not_close_rows
=
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
num_not_close_cols
=
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
num_not_close_cols
=
(
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
# allow for 1:500 error due to rounding differences
# allow for 1:500 error due to rounding differences
min_error
=
1
/
500
min_error
=
1
/
500
if
num_not_close_cols
>
(
min_error
*
n
):
if
num_not_close_cols
>
(
min_error
*
n
):
print
(
print
(
f
"Min error exceeded
{
num_not_close_cols
}
elements are different. Error:
{
num_not_close_cols
/
n
:.
4
f
}
"
)
f
"Min error exceeded
{
num_not_close_cols
}
elements are different. Error:
{
num_not_close_cols
/
n
:.
4
f
}
"
)
assert
False
assert
False
if
num_not_close_rows
>
(
min_error
*
n
):
if
num_not_close_rows
>
(
min_error
*
n
):
print
(
print
(
f
"Min error exceeded
{
num_not_close_rows
}
elements are different. Error:
{
num_not_close_rows
/
n
:.
4
f
}
"
)
f
"Min error exceeded
{
num_not_close_rows
}
elements are different. Error:
{
num_not_close_rows
/
n
:.
4
f
}
"
)
assert
False
assert
False
torch
.
testing
.
assert_close
(
Srow
.
flatten
().
float
(),
statsA
)
torch
.
testing
.
assert_close
(
Srow
.
flatten
().
float
(),
statsA
)
...
@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
...
@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
(
"dim1"
,
"dim4"
,
"inner"
),
(
"dim1"
,
"dim4"
,
"inner"
),
(
(
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
for
(
dim1
,
dim4
,
inner
)
for
(
dim1
,
dim4
,
inner
)
in
zip
(
in
zip
(
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
get_test_dims
(
1
,
4
*
1024
,
n
=
4
),
)
)
)
)
,
)
)
def
test_integrated_igemmlt
(
dim1
,
dim4
,
inner
):
def
test_integrated_igemmlt
(
dim1
,
dim4
,
inner
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
...
@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
...
@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
(
"dim1"
,
"dim4"
,
"inner"
),
(
"dim1"
,
"dim4"
,
"inner"
),
(
(
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
pytest
.
param
(
dim1
,
dim4
,
inner
,
id
=
f
"
{
dim1
=
}
,
{
dim4
=
}
,
{
inner
=
}
"
)
for
(
dim1
,
dim4
,
inner
)
for
(
dim1
,
dim4
,
inner
)
in
zip
(
in
zip
(
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
get_test_dims
(
1
,
4
*
1024
,
n
=
6
),
)
)
)
)
,
)
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
...
@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
...
@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c
=
10.0
*
inner
*
scale
c
=
10.0
*
inner
*
scale
row_scale
=
torch
.
ones_like
(
maxA
)
/
c
row_scale
=
torch
.
ones_like
(
maxA
)
/
c
outC32
,
SC
=
F
.
igemmlt
(
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
C3
,
S
=
F
.
nvidia_transform
(
outC32
,
"row"
,
state
=
SC
)
C3
,
S
=
F
.
nvidia_transform
(
outC32
,
"row"
,
state
=
SC
)
maxval
=
torch
.
abs
(
C3
).
max
()
maxval
=
torch
.
abs
(
C3
).
max
()
if
maxval
==
127
:
if
maxval
==
127
:
...
@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
...
@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
for
i
in
range
(
k
):
outC32
,
SC
=
F
.
igemmlt
(
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
"row-wise"
,
time
.
time
()
-
t0
)
print
(
"row-wise"
,
time
.
time
()
-
t0
)
...
@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
...
@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
def
test_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
def
test_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
for
i
in
range
(
k
):
for
i
in
range
(
k
):
if
dims
==
2
:
if
dims
==
2
:
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
dtype
)
elif
dims
==
3
:
elif
dims
==
3
:
A
=
torch
.
randint
(
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
10
,
99
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
A
.
view
(
-
1
)[
-
1
]
=
-
1
A
.
view
(
-
1
)[
-
1
]
=
-
1
if
transpose
:
if
transpose
:
...
@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
...
@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
idx
=
torch
.
abs
(
A
)
>=
threshold
idx
=
torch
.
abs
(
A
)
>=
threshold
CA2
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
CA2
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
,
threshold
=
threshold
)
A
,
threshold
=
threshold
)
if
coo_tensor
is
not
None
:
if
coo_tensor
is
not
None
:
A1
=
A
*
idx
A1
=
A
*
idx
A2
=
torch
.
zeros_like
(
A
)
A2
=
torch
.
zeros_like
(
A
)
A2
[
A2
[
coo_tensor
.
rowidx
.
long
(),
coo_tensor
.
colidx
.
long
()]
=
coo_tensor
.
values
coo_tensor
.
rowidx
.
long
(),
coo_tensor
.
colidx
.
long
()
]
=
coo_tensor
.
values
torch
.
testing
.
assert_close
(
A1
,
A2
)
torch
.
testing
.
assert_close
(
A1
,
A2
)
A1
=
A
*
(
idx
==
0
)
A1
=
A
*
(
idx
==
0
)
A2
=
(
CA
.
float
()
*
statsA
.
unsqueeze
(
1
)
/
127
).
half
()
A2
=
(
CA
.
float
()
*
statsA
.
unsqueeze
(
1
)
/
127
).
half
()
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
A
*
(
idx
==
0
),
A2
,
rtol
=
0.05
,
atol
=
1.5e-2
)
A
*
(
idx
==
0
),
A2
,
rtol
=
0.05
,
atol
=
1.5e-2
)
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
1
,
1
*
1024
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
get_test_dims
(
1
,
1
*
1024
,
n
=
2
),
ids
=
id_formatter
(
"dim1"
))
...
@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
...
@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz
=
(
idx
==
1
).
sum
().
item
()
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
A2
=
A
*
idx
if
transposed_B
:
if
transposed_B
:
...
@@ -1303,9 +1219,7 @@ def test_spmm_bench():
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    for i in range(10):
        out2 = F.spmm_coo(cooA, B)
...
@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
...
@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
...
@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))
...
@@ -1443,9 +1351,7 @@ def test_coo2csr():
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
...
@@ -1463,9 +1369,7 @@ def test_coo2csc():
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
...
@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
...
@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
    # linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
...
@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
    )

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    # torch.cuda.synchronize()
    # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    # torch.cuda.synchronize()
    # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul(A, B)
    # torch.cuda.synchronize()
    # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    # torch.cuda.synchronize()
    # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    # C32A, SA = F.transform(CA, "col32")
    # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    # CxB, SB = F.transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    # torch.cuda.synchronize()
    # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # BA, statsB = F.vectorwise_quant(B, dim=1)
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    # torch.cuda.synchronize()
    # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    # torch.cuda.synchronize()
    # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #     linear8bit(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linearMixedBit(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #     linearMixedBit(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit_train(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #     linear8bit_train(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit_train_thresh(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #     linear8bit_train(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")


def test_zeropoint():
    def quant_zp(x):
...
@@ -1778,8 +1682,8 @@ def test_zeropoint():
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    # print(ca.min(), ca.max())
    # print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
...
@@ -1808,14 +1712,14 @@ def test_zeropoint():
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    # print("")
    # print(C0.flatten()[:10])
    # print(C1.flatten()[:10])
    # print(C2.flatten()[:10])
    # print(C3.flatten()[:10])
    # print(C5.flatten()[:10])
    # print(C6.flatten()[:10])
    # print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
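test_zeropoint compares several algebraic ways of folding a zero-point into an int8 matmul. As background only, here is a generic sketch of affine (scale plus zero-point) quantization written for this page; it is not the quant_zp helper the test itself defines.

import torch

def affine_quantize_int8(x: torch.Tensor):
    # Map [x.min(), x.max()] onto the symmetric int8 range [-127, 127].
    scale = (x.max() - x.min()).clamp(min=1e-8) / 254.0
    zero_point = torch.round(-x.min() / scale) - 127
    q = torch.clamp(torch.round(x / scale + zero_point), -127, 127)
    return q, scale, zero_point

def affine_dequantize(q, scale, zero_point):
    # Invert the mapping (up to rounding error).
    return (q - zero_point) * scale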
...
@@ -1852,16 +1756,15 @@ def test_extract_outliers():
    torch.testing.assert_close(outliers1, outliers2)


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:  # , 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device="cpu")
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
...
@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
    # print(sum(reldiffs)/len(reldiffs))


def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
...
@@ -1888,12 +1790,12 @@ def test_fp8_quant():
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
...
@@ -1902,12 +1804,12 @@ def test_fp8_quant():
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
...
@@ -1916,50 +1818,48 @@ def test_fp8_quant():
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(3, sum(abserr)/len(abserr))
        # print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():
    # print('')
    for bits in range(2, 9):
        # print('='*30, bits, '='*30)
        for method in ["linear", "fp8", "dynamic", "quantile"]:
            abserrs = []
            relerrs = []
            code = None
            if method == "linear":
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == "fp8":
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == "dynamic":
                code = F.create_dynamic_map(True, bits - 0, bits).cuda()
            elif method == "quantile":
                values = torch.randn(2048, 2048, device="cuda")
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
            # print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):
                values = torch.randn(1, 32, device="cuda")
                values /= values.abs().max()
                # values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v - code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())
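A compact sketch of the round trip that test_fp8_quant measures, reusing only calls that appear in the hunks above (F.create_fp8_map, F.quantize_blockwise, F.dequantize_blockwise); the 4-exponent/3-mantissa split and the 1024x1024 shape are arbitrary illustrative choices, not values from the tests.

import torch
import bitsandbytes.functional as F

code = F.create_fp8_map(True, 4, 3).cuda()      # signed 8-bit float-style code map
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)     # uint8 codes plus per-block quantization state
A2 = F.dequantize_blockwise(C, SC)              # reconstruction
mean_abs_err = (A1 - A2).abs().mean().item()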
...
@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1
                else:
                    torch.testing.assert_close(q1, q2)
            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    # assert False


def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            p = np.linspace(0, 1, 2 * total_values + 1)
            idx = np.arange(1, 2 * total_values + 1, 2)
            p = p[idx]
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035


@pytest.mark.benchmark
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device="cuda").half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
    # print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/1e6)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
...
@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
        result = 0
        bias = 3

        sign, e1, e2, p1 = bits
        idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
        sign = -1.0 if sign else 1.0
        exp = e1 * 2 + e2 * 1
        if exp == 0:
            # sub-normal
            if p1 == 0:
                result = 0
            else:
                result = sign * 0.0625
        else:
            # normal
            exp = 2 ** (-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign * exp * frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err / (A1.abs().float() + 1e-8)).mean()
    idx = err > 1.0
    err = err.mean()
...
@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
    assert relerr.item() < 0.28


@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device="cuda").half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())
...
@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
            assert err.item() < 0.11
            assert relerr.item() < 0.28

        # print(sum(errs1)/len(errs1), blocksize, quant_type)
        # print(sum(errs2)/len(errs2), blocksize, quant_type)


# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel() / 2
    output_size = a.numel() * 2
    num_bytes = input_size + output_size
    GB = num_bytes / 1e9
    max_theoretical_s = GB / 768
    # print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024 * 12, device="cuda").half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        # b.copy_(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    torch.matmul(b, a.t())
    # torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)


def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    # print(values)
    while num_pivots < 16:
        idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
        # print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i - 1] + values[i]) / 2)
        # print(pivots)


@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
        # for dim in [4*1024]:
        # for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
...
@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            if kind == "fc1":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "fc2":
                A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn_packed":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)

            qB, state = F.quantize_4bit(
                B,
                quant_type=storage_type,
                compress_statistics=double_quant,
                quant_storage=quant_storage,
            )
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1 - C2).abs().float()
            err2 = (C3 - C2).abs().float()
            err3 = (C3 - C1).abs().float()

            mag1 = torch.abs(C1).float() + 1e-5
            mag2 = torch.abs(C3).float() + 1e-5
            mag3 = torch.abs(C3).float() + 1e-5

            relerr1 = err1 / mag1
            relerr2 = err2 / mag2
            relerr3 = err3 / mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
...
@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel() * 0.0014 * (dim / 256)) + 1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
        absratio = err2 / err3
        relratio = relerr2 / relerr3
        maxratio = relerr2 / relerr3

        # for debugging if the tests fails
        #
        # print('='*80)
        # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        # print(C1.flatten()[-20:])
        # print(C2.flatten()[-20:])
        # print(f'inference vs training abs: {err1}')
        # print(f'inference vs training rel: {relerr1}')
        # print(f'inference vs training max: {maxerr1}')
        # print(f'inference vs training vs torch err ratio abs: {absratio}')
        # print(f'inference vs training vs torch err ratio rel: {relratio}')
        # print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
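The assertions in this hunk compare three ways of computing the same product. Here is a minimal sketch of that comparison, using only the calls the test itself uses (F.quantize_4bit, F.gemv_4bit, bnb.matmul_4bit) with an arbitrary fc1-style shape; it assumes a CUDA device and fp16 inputs.

import math
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 1024
A = torch.randn(1, dim, dtype=torch.float16, device="cuda")
B = torch.randn(dim * 4, dim, dtype=torch.float16, device="cuda") / math.sqrt(dim)

qB, state = F.quantize_4bit(B, quant_type="nf4")    # 4-bit weight plus quantization state
C_ref = torch.matmul(A, B.t())                      # fp16 reference
C_gemv = F.gemv_4bit(A, qB.t(), state=state)        # fused 4-bit inference kernel
C_mm = bnb.matmul_4bit(A, qB.t(), state)            # autograd-aware wrapper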
...
@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98


@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    n = 32 * 10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid == 0
    assert B.page_deviceid == 0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A == 17).sum().item() == n * n
    assert (B == 17).sum().item() == n * n
    C = A * B.float()
    assert (C == 289).sum().item() == n * n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A == 17 * (2**3)).sum().item() == n * n
    # F.prefetch_tensor(A)
    # F.prefetch_tensor(B)

    # F.fill(B2, 17.0)
    # F._mul(A, B2)

    # F.prefetch_tensor(A, to_cpu=True)
    # F.prefetch_tensor(B, to_cpu=True)
    # F.prefetch_tensor(B2, to_cpu=True)
    # torch.cuda.synchronize()

    # assert (A==17).sum().item() == n*n

    # torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=dims)
    dims = [dim + (64 - (dim % 64)) for dim in dims]
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
...
@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
tests/test_generation.py
View file @ 06029dd6
...
@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers")
def get_4bit_config():
    return transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        load_in_8bit=False,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )


def get_model_and_tokenizer(config):
    model_name_or_path, quant_type = config
    bnb_config = get_4bit_config()
    if quant_type == "16bit":
        bnb_config.load_in_4bit = False
    else:
        bnb_config.bnb_4bit_quant_type = quant_type
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=bnb_config,
        max_memory={0: "48GB"},
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)

    return model, tokenizer


def get_prompt_for_generation_eval(text, add_roles=True):
    description = (
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    if add_roles:
        prompt = f"{description} ### Human: {text} ### Assistant:"
    else:
        prompt = f"{description} {text}"
    return prompt


def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval):
    text = prompt_func(text)
    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
    outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
dtypes = ["nf4", "fp4"]


@pytest.fixture(scope="session", params=product(models, dtypes))
def model_and_tokenizer(request):
    model, tokenizer = get_model_and_tokenizer(request.param)
    yield request.param, model, tokenizer
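A hedged sketch of how the 4-bit config built above is consumed end to end; it reuses only arguments and the bloom-1b7 checkpoint that already appear in these tests and assumes a CUDA machine with transformers installed.

import torch
import transformers

# Minimal NF4 4-bit load of one of the checkpoints the tests use.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",
    quantization_config=bnb_config,
    device_map="auto",
).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-1b7")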
...
@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
    )
    generation_config.max_new_tokens = 20

    # text = 'Please write down the first 50 digits of pi.'
    # text = get_prompt_for_generation_eval(text)
    # text += ' Sure, here the first 50 digits of pi: 3.14159'
    n_cases = 6
    text = "3.14159"
    if hasattr(model.config, "quantization_config"):
        model.config.quantization_config.bnb_4bit_compute_dtype = dtype
        model.config.quantization_config.bnb_4bit_use_double_quant = DQ

    if not inference_kernel:
        text = [text] * n_cases
    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
    x = inputs["input_ids"]
    outputs = []
    if inference_kernel:
        for i in range(n_cases):
...
@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
        outputs = model.generate(x, generation_config=generation_config)
        outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    assert len(outputs) == n_cases
    failure_count = 0
    for i in range(n_cases):
        if not outputs[i][: len(str(math.pi))] == str(math.pi):
            failure_count += 1
    failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4
    if failure_count > failure_max:
        print(math.pi)
        for out in outputs:
            print(out)
        raise ValueError(f"Failure count: {failure_count}/{n_cases}")
tests/test_linear4bit.py
View file @ 06029dd6
...
@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
...
@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
...
@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    target_compression = (
        0.143 if original_dtype == torch.float32 else 0.29
    )  # these numbers get lower as weight shape increases
    ratio_error_msg = (
        f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
    )
    assert size_ratio < target_compression, ratio_error_msg
...
tests/test_linear8bitlt.py View file @ 06029dd6
...
@@ -19,6 +19,7 @@ from tests.helpers import (
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
...
@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
    linear_custom.state.force_no_igemmlt = True
    linear_custom.weight = bnb.nn.Int8Params(
-       linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+       linear.weight.data.clone(),
+       requires_grad=False,
+       has_fp16_weights=False,
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
...
@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
+def test_linear_serialization(
+    has_fp16_weights,
+    serialize_before_forward,
+    deserialize_before_cuda,
+    force_no_igemmlt,
+    save_before_forward,
+    load_before_cuda,
+):
    linear = torch.nn.Linear(32, 96)
    x = torch.randn(3, 32, dtype=torch.half)
...
@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
    linear_custom.state.force_no_igemmlt = True
    linear_custom.weight = bnb.nn.Int8Params(
-       linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+       linear.weight.data.clone(),
+       requires_grad=has_fp16_weights,
+       has_fp16_weights=has_fp16_weights,
    )
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
...
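As with the 4-bit test, the 8-bit test builds its layer by hand: an Int8Params wrapper around a cloned fp16 weight is assigned to a Linear8bitLt, and quantization happens when the module is moved to CUDA. A small sketch of that construction, assuming a CUDA-enabled install; the threshold value and helper name are illustrative:

import torch
import bitsandbytes as bnb

def to_int8_inference_layer(linear: torch.nn.Linear, threshold: float = 6.0) -> bnb.nn.Linear8bitLt:
    layer = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
        has_fp16_weights=False,  # inference-only: keep just the int8 weight
        threshold=threshold,
    )
    layer.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(),
        requires_grad=False,
        has_fp16_weights=False,
    ).to(linear.weight.dtype)
    layer.bias = linear.bias
    return layer.cuda()  # the move to CUDA triggers the int8 quantization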
tests/test_modules.py View file @ 06029dd6
...
@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
-           dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
-           threshold=threshold
+           dim1,
+           dim2,
+           has_fp16_weights=has_fp16_weights,
+           memory_efficient_backward=memory_efficient_backward,
+           threshold=threshold,
        )
        self.fc2 = bnb.nn.Linear8bitLt(
-           dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
-           threshold=threshold
+           dim2,
+           dim1,
+           has_fp16_weights=has_fp16_weights,
+           memory_efficient_backward=memory_efficient_backward,
+           threshold=threshold,
        )

    def forward(self, x):
...
@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
-       round_func = (
-           LinearFunction.round_stoachastic if stochastic else torch.round
-       )
+       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
...
@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function):
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
-       round_func = (
-           LinearFunction.round_stoachastic if stochastic else torch.round
-       )
+       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
...
@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
-       round_func = (
-           LinearFunction.round_stoachastic if stochastic else torch.round
-       )
+       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
...
@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function):
        weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
        x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
        outputq = bnb.functional.igemm(x8, weight8.t())
-       output = LinearFunction.dequant(
-           outputq, S1, S2, x.dtype, args.quant_type
-       )
+       output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
        # if torch.rand(1) < 0.01:
        # output32 = torch.matmul(x, weight.t())
        # err = torch.abs(output-output32).float()
...
@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function):
        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
-           grad_output8, S1 = LinearFunction.quant(
-               grad_output, args.quant_type, dim=[0, 1]
-           )
+           grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
-           grad_weight = LinearFunction.dequant(
-               grad_weight8, S1, S2, grad_output.dtype, args.quant_type
-           )
+           grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
-           grad_output8, S1 = LinearFunction.quant(
-               grad_output, args.quant_type, dim=[0, 1]
-           )
+           grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
-           grad_weight = LinearFunction.dequant(
-               grad_weight8, S1, S2, grad_output.dtype, args.quant_type
-           )
+           grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
-           grad_output8, S1 = LinearFunction.quant(
-               grad_output, args.quant_type, dim=2
-           )
+           grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
-           grad_input = LinearFunction.dequant(
-               grad_input8, S1, S3, grad_output.dtype, args.quant_type
-           )
+           grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
        else:
            grad_input = grad_output.matmul(weight)
...
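The helpers reformatted above all use the same absmax scheme; the vector-wise round trip in get_8bit_vector_wise boils down to a few tensor ops. A dependency-free restatement for reference (deterministic rounding only, the stochastic variant is omitted):

import torch

def quant_vectorwise(x: torch.Tensor, dim: int):
    # Scale each vector along `dim` by its absolute maximum so values fit into int8.
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0  # avoid division by zero for all-zero vectors
    x8 = torch.round((x * 127) / max1).to(torch.int8)
    return x8, max1

def dequant_vectorwise(x8: torch.Tensor, max1: torch.Tensor, dtype=torch.float32):
    return (x8.to(dtype) * max1) / 127

x = torch.randn(4, 8)
x8, scale = quant_vectorwise(x, dim=1)
x_hat = dequant_vectorwise(x8, scale)
print((x - x_hat).abs().max())  # per-element error is at most scale / 254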
@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
-           assert_all_approx_close(
-               l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
-           )
-           assert_all_approx_close(
-               l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
-           )
+           assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+           assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
...
@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-   l1 = (bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
+   l1 = (
+       bnb.nn.Linear8bitLt(
+           32,
+           64,
+           threshold=threshold,
+           has_fp16_weights=False,
+           memory_efficient_backward=memory_efficient_backward,
+       )
+       .cuda()
+       .half()
+   )
    assert l1.weight.dtype == torch.int8
    l1.eval()
...
@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    if threshold > 0:
        assert mlp.fc2.state.idx is not None
-   mlp = (
-       MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-       .cuda()
-       .half()
-   )
+   mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
...
@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    if threshold > 0:
        assert mlp.fc2.state.idx is not None
-   mlp = (
-       MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-       .half()
-       .cuda()
-   )
+   mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
-   mlp = (MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
+   mlp = (
+       MLP8bit(
+           32,
+           64,
+           threshold=threshold,
+           has_fp16_weights=False,
+           memory_efficient_backward=memory_efficient_backward,
+       )
+       .half()
+       .to("cuda")
+   )
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    assert mlp.fc2.weight.device.type == "cuda"
    mlp = MLP8bit(
-       32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+       32,
+       64,
+       threshold=threshold,
+       has_fp16_weights=False,
+       memory_efficient_backward=memory_efficient_backward,
    )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization
...
@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
-   ids=['Int8Lt', 'FP4'],
+   ids=["Int8Lt", "FP4"],
)
def test_linear_kbit_fp32_bias(module):
    # casts model to fp16 -> int8 automatically
...
@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
-   kbit = kbit.half().to('cuda')
+   kbit = kbit.half().to("cuda")
    errs1 = []
    errs2 = []
...
@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad
        err1 = (out1 - out2).abs().float()
        err2 = (grad1 - grad2).abs().float()
-       relerr1 = (err1 / (out1.abs().float() + 1e-9))
-       relerr2 = (err2 / (grad1.abs().float() + 1e-9))
+       relerr1 = err1 / (out1.abs().float() + 1e-9)
+       relerr2 = err2 / (grad1.abs().float() + 1e-9)
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
...
@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
    assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
    assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
-   #print('out', sum(errs1)/len(errs1))
-   #print('grad', sum(errs2)/len(errs2))
-   #print('rel out', sum(relerrs1)/len(relerrs1))
-   #print('rel grad', sum(relerrs2)/len(relerrs2))
+   # print('out', sum(errs1)/len(errs1))
+   # print('grad', sum(errs2)/len(errs2))
+   # print('rel out', sum(relerrs1)/len(relerrs1))
+   # print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():
    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
    fp32b = torch.nn.Linear(h * 2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
...
@@ -605,34 +603,34 @@ def test_fp8linear():
    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
    err = (a - b).abs().mean()
    a.mean().backward()
    b.mean().backward()
    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

def test_4bit_warnings():
    dim1 = 64
-   with pytest.warns(UserWarning, match=r'inference or training'):
+   with pytest.warns(UserWarning, match=r"inference or training"):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)
-   with pytest.warns(UserWarning, match=r'inference.'):
+   with pytest.warns(UserWarning, match=r"inference."):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)
    with pytest.warns(UserWarning) as record:
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
...
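test_fp8linear, reformatted above, is a compact example of how the research FP8 layer is meant to be validated: copy weights from an fp32 reference, run both, and compare outputs and gradients. A trimmed-down, single-layer version of that check; it assumes a CUDA GPU and the bitsandbytes research module:

import torch
import bitsandbytes as bnb

h = 1024
inp = torch.randn(10, h).cuda()

fp32 = torch.nn.Linear(h, h * 2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)

a = fp32(inp)
b = fp8(inp)
err = (a - b).abs().mean()

a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
print(err.item(), graderr.item())  # the full test above asserts err < 0.05 and graderr < 2e-5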
tests/test_optim.py View file @ 06029dd6
...
@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter
k = 20

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
...
@@ -33,6 +34,7 @@ def get_temp_dir():
def rm_path(path):
    shutil.rmtree(path)

str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
...
@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = (
)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
-str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
+str2optimizers["paged_adamw8bit_blockwise"] = (
+    torch.optim.AdamW,
+    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
+)
+str2optimizers["paged_adam8bit_blockwise"] = (
+    torch.optim.Adam,
+    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
+)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
...
@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
-str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
-str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
-str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
+str2statenames["adam8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["paged_adam8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["paged_adamw8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
...
@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

-optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
+optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]

@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
...
@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
-   if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']:
+   if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
...
@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        for name1, name2 in str2statenames[optim_name]:
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
-           assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
-                                    atol=atol, rtol=rtol,
-                                    max_error_count=10)
+           assert_most_approx_close(
+               torch_optimizer.state[p1][name1],
+               bnb_optimizer.state[p2][name2],
+               atol=atol,
+               rtol=rtol,
+               max_error_count=10,
+           )
        if gtype != torch.float32:
            # the adam buffers should also be close because they are 32-bit
...
@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8
    bnb.optim.GlobalOptimManager.get_instance().initialize()
-   bnb.optim.GlobalOptimManager.get_instance().override_config(
-       p3, "optim_bits", 8
-   )
-   bnb.optim.GlobalOptimManager.get_instance().register_parameters(
-       [p1, p2, p3]
-   )
+   bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+   bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()
...
@@ -242,7 +259,8 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-   if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
+   if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
+       pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...
@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                absmax=bnb_optimizer.state[p2][max_val],
                A=bnb_optimizer.state[p2][name2],
            )
-           num_not_close = (
-               torch.isclose(
-                   torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
-               )
-               == 0
-           )
-           #assert num_not_close.sum().item() < 20
+           num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
+           # assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())
        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0016
...
@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        relerrors.append(relerr.mean().item())
        if i % 10 == 0 and i > 0:
-           for (name1, name2, qmap, max_val), s in zip(
-               str2statenames[optim_name], dequant_states
-           ):
+           for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()
...
@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                )
                torch.testing.assert_close(s1cpy, s1)
-               num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
+               num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
                assert num_not_close.sum().item() < 20
        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 5 errors for Lion
...
@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    for i in range(50):
        step += 1
-       g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
-           0.01 * i
-       )
+       g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2
-       current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
-           g1, gnorm_vec, step, 5
-       )
+       current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1
...
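The percentile_clipping call collapsed above implements adaptive gradient clipping: it tracks recent gradient norms and returns a scale factor that clips the current gradient to the requested percentile (5 in the test). A sketch of how it is driven, assuming a 100-slot history buffer for the running norms (the buffer size here is an assumption, not taken from the diff):

import torch
import bitsandbytes.functional as F

p = torch.randn(1024, 1024, device="cuda")
gnorm_vec = torch.zeros(100, device="cuda")  # assumed size of the rolling-norm buffer

for step in range(1, 51):
    g = torch.randn_like(p) * 0.1 + 0.01 * step  # gradients that slowly grow, as in the test
    current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g, gnorm_vec, step, 5)
    g = (g.float() * gnorm_scale).to(g.dtype)  # rescale before handing the gradient to the optimizer
    p.grad = g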
@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name"))
-@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode"))
+@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
+@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
...
@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = layers1.cuda()

    large_tensor = None
-   if mode == 'torch':
+   if mode == "torch":
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # 12 GB
-       large_tensor = torch.empty((int(4.5e9),), device='cuda')
+       large_tensor = torch.empty((int(4.5e9),), device="cuda")

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
-   batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+   batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            torch.cuda.synchronize()
            t0 = time.time()
...
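Most of the optimizer tests reformatted above follow the recipe encoded in str2optimizers: run a bitsandbytes optimizer next to its torch (or Lion) reference on identical parameters and gradients, then check that parameters and optimizer state stay within a tolerance. A minimal standalone version of that recipe for the 8-bit blockwise Adam; the tensor shape, learning rate, and step count are illustrative:

import torch
import bitsandbytes as bnb

p_ref = torch.randn(1024, 32, device="cuda") * 0.1
p_bnb = p_ref.clone()

opt_ref = torch.optim.Adam([p_ref], lr=1e-3)
opt_bnb = bnb.optim.Adam8bit([p_bnb], lr=1e-3, block_wise=True)

for _ in range(20):
    g = torch.randn_like(p_ref) * 0.01
    p_ref.grad = g.clone()
    p_bnb.grad = g.clone()
    opt_ref.step()
    opt_bnb.step()

print((p_ref - p_bnb).abs().mean().item())  # small but nonzero: the 8-bit state quantizes the moments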