Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4ea489d3
"src/diffusers/utils/dynamic_modules_utils.py" did not exist on "78744b6a8f3c9dd4800e1b279cc37930dfd77048"
Commit
4ea489d3
authored
Apr 03, 2023
by
Tim Dettmers
Browse files
Refactor FP4 into 4Bit and integrate NF4 data type.
parent
64cc0592
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
145 additions
and
90 deletions
+145
-90
bitsandbytes/__init__.py
bitsandbytes/__init__.py
+1
-1
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+3
-3
bitsandbytes/functional.py
bitsandbytes/functional.py
+11
-10
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+1
-1
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+17
-9
csrc/kernels.cu
csrc/kernels.cu
+47
-40
tests/test_autograd.py
tests/test_autograd.py
+8
-7
tests/test_functional.py
tests/test_functional.py
+26
-16
tests/test_modules.py
tests/test_modules.py
+31
-3
No files found.
bitsandbytes/__init__.py
View file @
4ea489d3
...
@@ -10,7 +10,7 @@ from .autograd._functions import (
...
@@ -10,7 +10,7 @@ from .autograd._functions import (
matmul
,
matmul
,
matmul_cublas
,
matmul_cublas
,
mm_cublas
,
mm_cublas
,
matmul_
fp4
matmul_
4bit
)
)
from
.cextension
import
COMPILED_WITH_CUDA
from
.cextension
import
COMPILED_WITH_CUDA
from
.nn
import
modules
from
.nn
import
modules
...
...
bitsandbytes/autograd/_functions.py
View file @
4ea489d3
...
@@ -475,7 +475,7 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -475,7 +475,7 @@ class MatMul8bitLt(torch.autograd.Function):
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
class
MatMul
FP4
(
torch
.
autograd
.
Function
):
class
MatMul
4Bit
(
torch
.
autograd
.
Function
):
# forward is the same, but we added the fallback for pre-turing GPUs
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
...
@@ -547,6 +547,6 @@ def matmul(
...
@@ -547,6 +547,6 @@ def matmul(
return
MatMul8bitLt
.
apply
(
A
,
B
,
out
,
bias
,
state
)
return
MatMul8bitLt
.
apply
(
A
,
B
,
out
,
bias
,
state
)
def
matmul_
fp4
(
A
:
tensor
,
B
:
tensor
,
quant_state
:
List
,
out
:
tensor
=
None
,
bias
=
None
):
def
matmul_
4bit
(
A
:
tensor
,
B
:
tensor
,
quant_state
:
List
,
out
:
tensor
=
None
,
bias
=
None
):
assert
quant_state
is
not
None
assert
quant_state
is
not
None
return
MatMul
FP4
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
return
MatMul
4Bit
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
bitsandbytes/functional.py
View file @
4ea489d3
...
@@ -689,14 +689,14 @@ def dequantize_blockwise(
...
@@ -689,14 +689,14 @@ def dequantize_blockwise(
return
out
return
out
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit
_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'fp4'
)
return
quantize_4bit
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'fp4'
)
def
quantize_nf4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
def
quantize_nf4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit
_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'nf4'
)
return
quantize_4bit
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'nf4'
)
def
quantize_4bit
_packed
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
,
quant_type
=
'fp4'
)
->
Tensor
:
def
quantize_4bit
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
"""
Quantize tensor A in blocks of
FP4
values.
Quantize tensor A in blocks of
4-bit
values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
...
@@ -763,19 +763,19 @@ def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, b
...
@@ -763,19 +763,19 @@ def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, b
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
256
)
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
256
)
del
absmax
del
absmax
state
=
(
qabsmax
,
input_shape
,
A
.
dtype
,
blocksize
,
(
offset
,
state2
))
state
=
(
qabsmax
,
input_shape
,
A
.
dtype
,
blocksize
,
(
offset
,
state2
)
,
quant_type
)
else
:
else
:
state
=
(
absmax
,
input_shape
,
A
.
dtype
,
blocksize
,
None
)
state
=
(
absmax
,
input_shape
,
A
.
dtype
,
blocksize
,
None
,
quant_type
)
return
out
,
state
return
out
,
state
def
dequantize_fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
def
dequantize_fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit
_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'fp4'
)
return
dequantize_4bit
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'fp4'
)
def
dequantize_nf4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
def
dequantize_nf4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit
_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'nf4'
)
return
dequantize_4bit
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'nf4'
)
def
dequantize_4bit
_packed
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
,
quant_type
=
'fp4'
)
->
Tensor
:
def
dequantize_4bit
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
"""
Dequantizes FP4 blockwise quantized values.
Dequantizes FP4 blockwise quantized values.
...
@@ -812,7 +812,8 @@ def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None,
...
@@ -812,7 +812,8 @@ def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None,
shape
=
out
.
shape
shape
=
out
.
shape
dtype
=
out
.
dtype
dtype
=
out
.
dtype
else
:
else
:
absmax
,
shape
,
dtype
,
blocksize
,
compressed_stats
=
quant_state
absmax
,
shape
,
dtype
,
blocksize
,
compressed_stats
,
quant_type
=
quant_state
if
compressed_stats
is
not
None
:
if
compressed_stats
is
not
None
:
offset
,
state2
=
compressed_stats
offset
,
state2
=
compressed_stats
...
...
bitsandbytes/nn/__init__.py
View file @
4ea489d3
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
Linear
FP4
,
FP4Params
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
Linear
4bit
,
LinearNF4
,
Linear
FP4
,
Params
4bit
bitsandbytes/nn/modules.py
View file @
4ea489d3
...
@@ -133,18 +133,19 @@ class Embedding(torch.nn.Embedding):
...
@@ -133,18 +133,19 @@ class Embedding(torch.nn.Embedding):
return
emb
return
emb
class
FP4
Params
(
torch
.
nn
.
Parameter
):
class
Params
4bit
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
cls
.
quant_state
=
None
cls
.
quant_state
=
None
cls
.
blocksize
=
blocksize
cls
.
blocksize
=
blocksize
cls
.
compress_statistics
=
compress_statistics
cls
.
compress_statistics
=
compress_statistics
cls
.
quant_type
=
quant_type
if
data
is
None
:
if
data
is
None
:
data
=
torch
.
empty
(
0
)
data
=
torch
.
empty
(
0
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w_fp4
,
quant_state
=
bnb
.
functional
.
quantize_
fp4
(
w
,
blocksize
=
self
.
blocksize
,
compress_statistics
=
self
.
compress_statistics
)
w_fp4
,
quant_state
=
bnb
.
functional
.
quantize_
4bit
(
w
,
blocksize
=
self
.
blocksize
,
compress_statistics
=
self
.
compress_statistics
,
quant_type
=
self
.
quant_type
)
self
.
data
=
w_fp4
self
.
data
=
w_fp4
self
.
quant_state
=
quant_state
self
.
quant_state
=
quant_state
...
@@ -168,17 +169,16 @@ class FP4Params(torch.nn.Parameter):
...
@@ -168,17 +169,16 @@ class FP4Params(torch.nn.Parameter):
if
(
device
is
not
None
and
device
.
type
==
"cuda"
and
self
.
data
.
device
.
type
==
"cpu"
):
if
(
device
is
not
None
and
device
.
type
==
"cuda"
and
self
.
data
.
device
.
type
==
"cpu"
):
return
self
.
cuda
(
device
)
return
self
.
cuda
(
device
)
else
:
else
:
new_param
=
FP4
Params
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
new_param
=
Params
4bit
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
requires_grad
=
self
.
requires_grad
,
quant_state
=
self
.
quant_state
)
requires_grad
=
self
.
requires_grad
,
quant_state
=
self
.
quant_state
)
return
new_param
return
new_param
class
Linear4bit
(
nn
.
Linear
):
class
LinearFP4
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
state
=
bnb
.
MatmulLtState
()
self
.
state
=
bnb
.
MatmulLtState
()
self
.
weight
=
FP4
Params
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
)
self
.
weight
=
Params
4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
self
.
compute_dtype
=
compute_dtype
self
.
compute_dtype
=
compute_dtype
def
init_8bit_state
(
self
):
def
init_8bit_state
(
self
):
...
@@ -198,12 +198,20 @@ class LinearFP4(nn.Linear):
...
@@ -198,12 +198,20 @@ class LinearFP4(nn.Linear):
x
=
x
.
to
(
self
.
compute_dtype
)
x
=
x
.
to
(
self
.
compute_dtype
)
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
()
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
()
out
=
bnb
.
matmul_
fp4
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
bnb
.
matmul_
4bit
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
out
.
to
(
inp_dtype
)
out
=
out
.
to
(
inp_dtype
)
return
out
return
out
class
LinearFP4
(
Linear4bit
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'fp4'
)
class
LinearNF4
(
Linear4bit
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
)
class
Int8Params
(
torch
.
nn
.
Parameter
):
class
Int8Params
(
torch
.
nn
.
Parameter
):
def
__new__
(
def
__new__
(
...
...
csrc/kernels.cu
View file @
4ea489d3
...
@@ -194,7 +194,7 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
...
@@ -194,7 +194,7 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
}
}
__device__
unsigned
char
dQuantizeN
ormal
(
float
x
)
__device__
unsigned
char
dQuantizeN
F4
(
float
x
)
{
{
// the values for this tree was generated by test_normal_map_tree
// the values for this tree was generated by test_normal_map_tree
...
@@ -221,7 +221,7 @@ __device__ unsigned char dQuantizeNormal(float x)
...
@@ -221,7 +221,7 @@ __device__ unsigned char dQuantizeNormal(float x)
if
(
x
>
0.1202552504837513
f
)
// 100
if
(
x
>
0.1202552504837513
f
)
// 100
return
0b1001
;
return
0b1001
;
else
else
return
0b1
1
00
;
return
0b1
0
00
;
else
else
if
(
x
>
-
0.33967943489551544
f
)
// 0
if
(
x
>
-
0.33967943489551544
f
)
// 0
if
(
x
>
-
0.13791173323988914
f
)
// 01
if
(
x
>
-
0.13791173323988914
f
)
// 01
...
@@ -726,8 +726,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
...
@@ -726,8 +726,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
#pragma unroll NUM_PER_TH
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
{
{
packed_4bit
|=
dQuantizeN
ormal
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeN
F4
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeN
ormal
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
packed_4bit
|=
dQuantizeN
F4
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
qvals
[
j
]
=
packed_4bit
;
qvals
[
j
]
=
packed_4bit
;
}
}
break
;
break
;
...
@@ -738,7 +738,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
...
@@ -738,7 +738,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
}
}
}
template
<
typename
T
,
int
TILE_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
FP4
>
template
<
typename
T
,
int
TILE_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
DATA_TYPE
>
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
)
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
)
{
{
...
@@ -747,19 +747,19 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
...
@@ -747,19 +747,19 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
int
valid_items_store
=
0
;
int
valid_items_store
=
0
;
const
int
base_idx
=
(
blockIdx
.
x
*
TILE_SIZE
);
const
int
base_idx
=
(
blockIdx
.
x
*
TILE_SIZE
);
T
vals
[
NUM_PER_TH
*
(
FP4
?
2
:
1
)];
T
vals
[
NUM_PER_TH
*
(
(
DATA_TYPE
>
0
)
?
2
:
1
)];
unsigned
char
qvals
[
NUM_PER_TH
];
unsigned
char
qvals
[
NUM_PER_TH
];
float
local_abs_max
=
-
FLT_MAX
;
float
local_abs_max
=
-
FLT_MAX
;
typedef
cub
::
BlockLoad
<
unsigned
char
,
THREADS
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadChar
;
typedef
cub
::
BlockLoad
<
unsigned
char
,
THREADS
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadChar
;
typedef
cub
::
BlockStore
<
T
,
THREADS
,
NUM_PER_TH
*
(
FP4
?
2
:
1
),
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreT
;
typedef
cub
::
BlockStore
<
T
,
THREADS
,
NUM_PER_TH
*
(
(
DATA_TYPE
>
0
)
?
2
:
1
),
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreT
;
__shared__
typename
LoadChar
::
TempStorage
loadchar
;
__shared__
typename
LoadChar
::
TempStorage
loadchar
;
__shared__
typename
StoreT
::
TempStorage
storet
;
__shared__
typename
StoreT
::
TempStorage
storet
;
for
(
unsigned
int
i
=
base_idx
;
i
<
n_load
;
i
+=
gridDim
.
x
*
TILE_SIZE
)
for
(
unsigned
int
i
=
base_idx
;
i
<
n_load
;
i
+=
gridDim
.
x
*
TILE_SIZE
)
{
{
if
(
FP4
)
if
(
DATA_TYPE
>
0
)
{
{
valid_items_load
=
(
n
+
1
)
/
2
-
i
>
TILE_SIZE
?
TILE_SIZE
:
(
n
+
1
)
/
2
-
i
;
valid_items_load
=
(
n
+
1
)
/
2
-
i
>
TILE_SIZE
?
TILE_SIZE
:
(
n
+
1
)
/
2
-
i
;
valid_items_store
=
n
-
i
*
2
>
TILE_SIZE
*
2
?
TILE_SIZE
*
2
:
n
-
i
*
2
;
valid_items_store
=
n
-
i
*
2
>
TILE_SIZE
*
2
?
TILE_SIZE
*
2
:
n
-
i
*
2
;
...
@@ -775,27 +775,34 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
...
@@ -775,27 +775,34 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
LoadChar
(
loadchar
).
Load
(
&
(
A
[
i
]),
qvals
,
valid_items_load
,
128
);
LoadChar
(
loadchar
).
Load
(
&
(
A
[
i
]),
qvals
,
valid_items_load
,
128
);
if
(
FP4
)
switch
(
DATA_TYPE
)
{
{
case
General8bit
:
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
vals
[
j
]
=
__ldg
(
&
code
[
qvals
[
j
]])
*
local_abs_max
;
break
;
case
FP4
:
#pragma unroll NUM_PER_TH
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
{
{
//vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
//vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
vals
[
j
*
2
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
>>
4
,
local_abs_max
);
vals
[
j
*
2
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
>>
4
,
local_abs_max
);
vals
[
j
*
2
+
1
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
vals
[
j
*
2
+
1
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
}
}
}
break
;
else
case
NF4
:
{
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
vals
[
j
]
=
__ldg
(
&
code
[
qvals
[
j
]])
*
local_abs_max
;
{
vals
[
j
*
2
]
=
dDequantizeNF4
(
qvals
[
j
]
>>
4
,
local_abs_max
);
vals
[
j
*
2
+
1
]
=
dDequantizeNF4
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
}
break
;
}
}
__syncthreads
();
__syncthreads
();
StoreT
(
storet
).
Store
(
&
(
out
[
FP4
?
i
*
2
:
i
]),
vals
,
valid_items_store
);
StoreT
(
storet
).
Store
(
&
(
out
[
(
DATA_TYPE
>
0
)
?
i
*
2
:
i
]),
vals
,
valid_items_store
);
}
}
}
}
...
...
tests/test_autograd.py
View file @
4ea489d3
...
@@ -440,7 +440,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
...
@@ -440,7 +440,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2
.
append
(
0
)
dim2
.
append
(
0
)
funcs
=
[(
torch
.
matmul
,
bnb
.
matmul_
fp4
)]
funcs
=
[(
torch
.
matmul
,
bnb
.
matmul_
4bit
)]
str_funcs
=
[
"matmul"
]
str_funcs
=
[
"matmul"
]
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad_str
=
[]
req_grad_str
=
[]
...
@@ -457,12 +457,13 @@ dtype = [torch.float16, torch.float32]
...
@@ -457,12 +457,13 @@ dtype = [torch.float16, torch.float32]
compress_statistics
=
[
False
,
True
]
compress_statistics
=
[
False
,
True
]
has_fp16_weights
=
[
True
,
False
]
has_fp16_weights
=
[
True
,
False
]
has_bias
=
[
True
,
False
]
has_bias
=
[
True
,
False
]
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
))
quant_type
=
[
'fp4'
,
'nf4'
]
str_values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
str_funcs
,
dtype
,
req_grad_str
,
str_transpose
,
has_bias
,
compress_statistics
))
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
))
names
=
[
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics"
.
format
(
*
vals
)
for
vals
in
str_values
]
str_values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
str_funcs
,
dtype
,
req_grad_str
,
str_transpose
,
has_bias
,
compress_statistics
,
quant_type
))
names
=
[
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}"
.
format
(
*
vals
)
for
vals
in
str_values
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics
, quant_type
"
,
values
,
ids
=
names
)
def
test_matmul_
fp4
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
):
def
test_matmul_
4bit
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
):
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
if
has_bias
==
False
:
if
has_bias
==
False
:
...
@@ -482,7 +483,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -482,7 +483,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias2
=
bias
.
clone
()
bias2
=
bias
.
clone
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_
fp4
(
B
,
compress_statistics
=
compress_statistics
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_
4bit
(
B
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
if
not
transpose
[
0
]
and
transpose
[
1
]:
if
not
transpose
[
0
]
and
transpose
[
1
]:
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
...
...
tests/test_functional.py
View file @
4ea489d3
...
@@ -1784,8 +1784,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
...
@@ -1784,8 +1784,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
print
(
"partial matmul"
,
time
.
time
()
-
t0
)
print
(
"partial matmul"
,
time
.
time
()
-
t0
)
batch_size
=
4
batch_size
=
2
seqdim
=
2
56
seqdim
=
2
048
values
=
[]
values
=
[]
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
values
.
append
((
batch_size
,
seqdim
,
1024
,
4
*
1024
))
values
.
append
((
batch_size
,
seqdim
,
1024
,
4
*
1024
))
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
iters
=
128
iters
=
32
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
...
@@ -1808,6 +1808,8 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1808,6 +1808,8 @@ def test_bench_matmul(batch, seq, model, hidden):
B_fp4
,
state
=
F
.
quantize_fp4
(
B
)
B_fp4
,
state
=
F
.
quantize_fp4
(
B
)
B_fp4_c
,
state_c
=
F
.
quantize_fp4
(
B
,
compress_statistics
=
True
)
B_fp4_c
,
state_c
=
F
.
quantize_fp4
(
B
,
compress_statistics
=
True
)
B_nf4
,
state_nf4
=
F
.
quantize_nf4
(
B
)
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit
.
eval
()
linear8bit
.
eval
()
...
@@ -1836,17 +1838,24 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1836,17 +1838,24 @@ def test_bench_matmul(batch, seq, model, hidden):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_
fp4
(
A
,
B_fp4
.
t
(),
quant_state
=
state
)
bnb
.
matmul_
4bit
(
A
,
B_fp4
.
t
(),
quant_state
=
state
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb fp4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb fp4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_
fp4
(
A
,
B_fp4
.
t
(),
quant_state
=
state_c
)
bnb
.
matmul_
4bit
(
A
,
B_fp4
.
t
(),
quant_state
=
state_c
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb fp4 + compressed stats: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb fp4 + compressed stats: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
bnb
.
matmul_4bit
(
A
,
B_nf4
.
t
(),
quant_state
=
state_nf4
)
torch
.
cuda
.
synchronize
()
print
(
f
"bnb nf4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
#torch.cuda.synchronize()
#torch.cuda.synchronize()
#t0 = time.time()
#t0 = time.time()
#for i in range(iters):
#for i in range(iters):
...
@@ -2262,17 +2271,18 @@ def test_4bit_compressed_stats(quant_type):
...
@@ -2262,17 +2271,18 @@ def test_4bit_compressed_stats(quant_type):
errs2
=
[]
errs2
=
[]
for
i
in
range
(
10
):
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
q2
,
SA2
=
F
.
quantize_4bit
_packed
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q2
,
SA2
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
_packed
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
_packed
(
q2
,
SA2
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
(
q2
,
SA2
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
_packed
(
q3
,
SA3
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
(
q3
,
SA3
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs1
.
append
(
relerr
.
item
())
errs1
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
...
@@ -2281,23 +2291,23 @@ def test_4bit_compressed_stats(quant_type):
...
@@ -2281,23 +2291,23 @@ def test_4bit_compressed_stats(quant_type):
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs2
.
append
(
rel
err
.
item
())
errs2
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
#print(sum(errs1)/len(errs1), blocksize)
#print(sum(errs1)/len(errs1), blocksize
, quant_type
)
#print(sum(errs2)/len(errs2), blocksize)
#print(sum(errs2)/len(errs2), blocksize
, quant_type
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
def
test_bench_
fp4
_dequant
(
quant_type
):
def
test_bench_
4bit
_dequant
(
quant_type
):
blocksize
=
256
blocksize
=
256
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'cuda'
).
half
()
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'cuda'
).
half
()
qa
,
SA
=
F
.
quantize_4bit
_packed
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
qa
,
SA
=
F
.
quantize_4bit
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
input_size
=
a
.
numel
()
/
2
input_size
=
a
.
numel
()
/
2
output_size
=
a
.
numel
()
*
2
output_size
=
a
.
numel
()
*
2
...
@@ -2311,7 +2321,7 @@ def test_bench_fp4_dequant(quant_type):
...
@@ -2311,7 +2321,7 @@ def test_bench_fp4_dequant(quant_type):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
F
.
dequantize_4bit
_packed
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
F
.
dequantize_4bit
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
#b.copy_(a)
#b.copy_(a)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/iters*1e6)
#print((time.time()-t0)/iters*1e6)
...
...
tests/test_modules.py
View file @
4ea489d3
...
@@ -506,8 +506,16 @@ def test_linear_kbit_fp32_bias(module):
...
@@ -506,8 +506,16 @@ def test_linear_kbit_fp32_bias(module):
o1
=
l1
(
b1
)
o1
=
l1
(
b1
)
assert
l1
.
bias
is
None
assert
l1
.
bias
is
None
modules
=
[]
modules
.
append
(
bnb
.
nn
.
Linear8bitLt
)
modules
.
append
(
bnb
.
nn
.
Linear4bit
)
modules
.
append
(
bnb
.
nn
.
LinearFP4
)
modules
.
append
(
bnb
.
nn
.
LinearNF4
)
modules
.
append
(
lambda
d1
,
d2
:
bnb
.
nn
.
LinearFP4
(
d1
,
d2
,
compress_statistics
=
True
))
modules
.
append
(
lambda
d1
,
d2
:
bnb
.
nn
.
LinearNF4
(
d1
,
d2
,
compress_statistics
=
True
))
names
=
[
'Int8Lt'
,
'4bit'
,
'FP4'
,
'NF4'
,
'FP4+C'
,
'NF4+C'
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"module"
,
[
bnb
.
nn
.
Linear8bitLt
,
bnb
.
nn
.
LinearFP4
,
lambda
d1
,
d2
:
bnb
.
nn
.
LinearFP4
(
d1
,
d2
,
compress_statistics
=
True
)],
ids
=
[
'Int8Lt'
,
'FP4'
,
'FP4+C'
]
)
@
pytest
.
mark
.
parametrize
(
"module"
,
modules
,
ids
=
names
)
def
test_kbit_backprop
(
module
):
def
test_kbit_backprop
(
module
):
b
=
17
b
=
17
dim1
=
37
dim1
=
37
...
@@ -515,6 +523,8 @@ def test_kbit_backprop(module):
...
@@ -515,6 +523,8 @@ def test_kbit_backprop(module):
ref
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
torch
.
nn
.
Linear
(
dim2
,
10
)])
ref
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
torch
.
nn
.
Linear
(
dim2
,
10
)])
ref
[
1
].
weight
.
requires_grad
=
False
ref
[
1
].
weight
.
requires_grad
=
False
torch
.
nn
.
init
.
kaiming_normal_
(
ref
[
0
].
weight
)
torch
.
nn
.
init
.
kaiming_normal_
(
ref
[
1
].
weight
)
kbit
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
module
(
dim2
,
10
)])
kbit
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
module
(
dim2
,
10
)])
kbit
[
0
].
weight
.
detach
().
copy_
(
ref
[
0
].
weight
)
kbit
[
0
].
weight
.
detach
().
copy_
(
ref
[
0
].
weight
)
kbit
[
1
].
weight
.
detach
().
copy_
(
ref
[
1
].
weight
)
kbit
[
1
].
weight
.
detach
().
copy_
(
ref
[
1
].
weight
)
...
@@ -523,6 +533,10 @@ def test_kbit_backprop(module):
...
@@ -523,6 +533,10 @@ def test_kbit_backprop(module):
ref
=
ref
.
half
().
cuda
()
ref
=
ref
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
errs1
=
[]
errs2
=
[]
relerrs1
=
[]
relerrs2
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
batch
=
torch
.
randn
(
b
,
dim1
).
half
().
cuda
()
batch
=
torch
.
randn
(
b
,
dim1
).
half
().
cuda
()
out1
=
ref
(
batch
)
out1
=
ref
(
batch
)
...
@@ -535,12 +549,26 @@ def test_kbit_backprop(module):
...
@@ -535,12 +549,26 @@ def test_kbit_backprop(module):
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
torch
.
testing
.
assert_allclose
(
grad1
,
grad2
,
atol
=
0.008
,
rtol
=
0.05
)
err1
=
(
out1
-
out2
).
abs
().
float
()
torch
.
testing
.
assert_allclose
(
bgrad1
,
bgrad2
,
atol
=
0.008
,
rtol
=
0.05
)
err2
=
(
grad1
-
grad2
).
abs
().
float
()
relerr1
=
(
err1
/
(
out1
.
abs
().
float
()
+
1e-9
))
relerr2
=
(
err2
/
(
grad1
.
abs
().
float
()
+
1e-9
))
errs1
.
append
(
err1
.
mean
().
item
())
errs2
.
append
(
err2
.
mean
().
item
())
relerrs1
.
append
(
relerr1
.
mean
().
item
())
relerrs2
.
append
(
relerr2
.
mean
().
item
())
#torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
#torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
ref
.
zero_grad
()
ref
.
zero_grad
()
kbit
.
zero_grad
()
kbit
.
zero_grad
()
assert
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
print
(
'out'
,
sum
(
errs1
)
/
len
(
errs1
))
print
(
'grad'
,
sum
(
errs2
)
/
len
(
errs2
))
print
(
'rel out'
,
sum
(
relerrs1
)
/
len
(
relerrs1
))
print
(
'rel grad'
,
sum
(
relerrs2
)
/
len
(
relerrs2
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment