Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4ea489d3
Commit
4ea489d3
authored
Apr 03, 2023
by
Tim Dettmers
Browse files
Refactor FP4 into 4Bit and integrate NF4 data type.
parent
64cc0592
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
145 additions
and
90 deletions
+145
-90
bitsandbytes/__init__.py
bitsandbytes/__init__.py
+1
-1
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+3
-3
bitsandbytes/functional.py
bitsandbytes/functional.py
+11
-10
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+1
-1
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+17
-9
csrc/kernels.cu
csrc/kernels.cu
+47
-40
tests/test_autograd.py
tests/test_autograd.py
+8
-7
tests/test_functional.py
tests/test_functional.py
+26
-16
tests/test_modules.py
tests/test_modules.py
+31
-3
No files found.
bitsandbytes/__init__.py
View file @
4ea489d3
...
@@ -10,7 +10,7 @@ from .autograd._functions import (
...
@@ -10,7 +10,7 @@ from .autograd._functions import (
matmul
,
matmul
,
matmul_cublas
,
matmul_cublas
,
mm_cublas
,
mm_cublas
,
matmul_
fp4
matmul_
4bit
)
)
from
.cextension
import
COMPILED_WITH_CUDA
from
.cextension
import
COMPILED_WITH_CUDA
from
.nn
import
modules
from
.nn
import
modules
...
...
bitsandbytes/autograd/_functions.py
View file @
4ea489d3
...
@@ -475,7 +475,7 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -475,7 +475,7 @@ class MatMul8bitLt(torch.autograd.Function):
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
class
MatMul
FP4
(
torch
.
autograd
.
Function
):
class
MatMul
4Bit
(
torch
.
autograd
.
Function
):
# forward is the same, but we added the fallback for pre-turing GPUs
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
...
@@ -547,6 +547,6 @@ def matmul(
...
@@ -547,6 +547,6 @@ def matmul(
return
MatMul8bitLt
.
apply
(
A
,
B
,
out
,
bias
,
state
)
return
MatMul8bitLt
.
apply
(
A
,
B
,
out
,
bias
,
state
)
def
matmul_
fp4
(
A
:
tensor
,
B
:
tensor
,
quant_state
:
List
,
out
:
tensor
=
None
,
bias
=
None
):
def
matmul_
4bit
(
A
:
tensor
,
B
:
tensor
,
quant_state
:
List
,
out
:
tensor
=
None
,
bias
=
None
):
assert
quant_state
is
not
None
assert
quant_state
is
not
None
return
MatMul
FP4
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
return
MatMul
4Bit
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
bitsandbytes/functional.py
View file @
4ea489d3
...
@@ -689,14 +689,14 @@ def dequantize_blockwise(
...
@@ -689,14 +689,14 @@ def dequantize_blockwise(
return
out
return
out
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit
_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'fp4'
)
return
quantize_4bit
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'fp4'
)
def
quantize_nf4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
def
quantize_nf4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit
_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'nf4'
)
return
quantize_4bit
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'nf4'
)
def
quantize_4bit
_packed
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
,
quant_type
=
'fp4'
)
->
Tensor
:
def
quantize_4bit
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
"""
Quantize tensor A in blocks of
FP4
values.
Quantize tensor A in blocks of
4-bit
values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
...
@@ -763,19 +763,19 @@ def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, b
...
@@ -763,19 +763,19 @@ def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, b
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
256
)
qabsmax
,
state2
=
quantize_blockwise
(
absmax
,
blocksize
=
256
)
del
absmax
del
absmax
state
=
(
qabsmax
,
input_shape
,
A
.
dtype
,
blocksize
,
(
offset
,
state2
))
state
=
(
qabsmax
,
input_shape
,
A
.
dtype
,
blocksize
,
(
offset
,
state2
)
,
quant_type
)
else
:
else
:
state
=
(
absmax
,
input_shape
,
A
.
dtype
,
blocksize
,
None
)
state
=
(
absmax
,
input_shape
,
A
.
dtype
,
blocksize
,
None
,
quant_type
)
return
out
,
state
return
out
,
state
def
dequantize_fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
def
dequantize_fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit
_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'fp4'
)
return
dequantize_4bit
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'fp4'
)
def
dequantize_nf4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
def
dequantize_nf4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit
_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'nf4'
)
return
dequantize_4bit
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'nf4'
)
def
dequantize_4bit
_packed
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
,
quant_type
=
'fp4'
)
->
Tensor
:
def
dequantize_4bit
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
"""
Dequantizes FP4 blockwise quantized values.
Dequantizes FP4 blockwise quantized values.
...
@@ -812,7 +812,8 @@ def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None,
...
@@ -812,7 +812,8 @@ def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None,
shape
=
out
.
shape
shape
=
out
.
shape
dtype
=
out
.
dtype
dtype
=
out
.
dtype
else
:
else
:
absmax
,
shape
,
dtype
,
blocksize
,
compressed_stats
=
quant_state
absmax
,
shape
,
dtype
,
blocksize
,
compressed_stats
,
quant_type
=
quant_state
if
compressed_stats
is
not
None
:
if
compressed_stats
is
not
None
:
offset
,
state2
=
compressed_stats
offset
,
state2
=
compressed_stats
...
...
bitsandbytes/nn/__init__.py
View file @
4ea489d3
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
Linear
FP4
,
FP4Params
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
Linear
4bit
,
LinearNF4
,
Linear
FP4
,
Params
4bit
bitsandbytes/nn/modules.py
View file @
4ea489d3
...
@@ -133,18 +133,19 @@ class Embedding(torch.nn.Embedding):
...
@@ -133,18 +133,19 @@ class Embedding(torch.nn.Embedding):
return
emb
return
emb
class
FP4
Params
(
torch
.
nn
.
Parameter
):
class
Params
4bit
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
cls
.
quant_state
=
None
cls
.
quant_state
=
None
cls
.
blocksize
=
blocksize
cls
.
blocksize
=
blocksize
cls
.
compress_statistics
=
compress_statistics
cls
.
compress_statistics
=
compress_statistics
cls
.
quant_type
=
quant_type
if
data
is
None
:
if
data
is
None
:
data
=
torch
.
empty
(
0
)
data
=
torch
.
empty
(
0
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w_fp4
,
quant_state
=
bnb
.
functional
.
quantize_
fp4
(
w
,
blocksize
=
self
.
blocksize
,
compress_statistics
=
self
.
compress_statistics
)
w_fp4
,
quant_state
=
bnb
.
functional
.
quantize_
4bit
(
w
,
blocksize
=
self
.
blocksize
,
compress_statistics
=
self
.
compress_statistics
,
quant_type
=
self
.
quant_type
)
self
.
data
=
w_fp4
self
.
data
=
w_fp4
self
.
quant_state
=
quant_state
self
.
quant_state
=
quant_state
...
@@ -168,17 +169,16 @@ class FP4Params(torch.nn.Parameter):
...
@@ -168,17 +169,16 @@ class FP4Params(torch.nn.Parameter):
if
(
device
is
not
None
and
device
.
type
==
"cuda"
and
self
.
data
.
device
.
type
==
"cpu"
):
if
(
device
is
not
None
and
device
.
type
==
"cuda"
and
self
.
data
.
device
.
type
==
"cpu"
):
return
self
.
cuda
(
device
)
return
self
.
cuda
(
device
)
else
:
else
:
new_param
=
FP4
Params
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
new_param
=
Params
4bit
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
requires_grad
=
self
.
requires_grad
,
quant_state
=
self
.
quant_state
)
requires_grad
=
self
.
requires_grad
,
quant_state
=
self
.
quant_state
)
return
new_param
return
new_param
class
Linear4bit
(
nn
.
Linear
):
class
LinearFP4
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
state
=
bnb
.
MatmulLtState
()
self
.
state
=
bnb
.
MatmulLtState
()
self
.
weight
=
FP4
Params
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
)
self
.
weight
=
Params
4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
self
.
compute_dtype
=
compute_dtype
self
.
compute_dtype
=
compute_dtype
def
init_8bit_state
(
self
):
def
init_8bit_state
(
self
):
...
@@ -198,12 +198,20 @@ class LinearFP4(nn.Linear):
...
@@ -198,12 +198,20 @@ class LinearFP4(nn.Linear):
x
=
x
.
to
(
self
.
compute_dtype
)
x
=
x
.
to
(
self
.
compute_dtype
)
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
()
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
()
out
=
bnb
.
matmul_
fp4
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
bnb
.
matmul_
4bit
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
out
.
to
(
inp_dtype
)
out
=
out
.
to
(
inp_dtype
)
return
out
return
out
class
LinearFP4
(
Linear4bit
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'fp4'
)
class
LinearNF4
(
Linear4bit
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
)
class
Int8Params
(
torch
.
nn
.
Parameter
):
class
Int8Params
(
torch
.
nn
.
Parameter
):
def
__new__
(
def
__new__
(
...
...
csrc/kernels.cu
View file @
4ea489d3
...
@@ -194,7 +194,7 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
...
@@ -194,7 +194,7 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
}
}
__device__
unsigned
char
dQuantizeN
ormal
(
float
x
)
__device__
unsigned
char
dQuantizeN
F4
(
float
x
)
{
{
// the values for this tree was generated by test_normal_map_tree
// the values for this tree was generated by test_normal_map_tree
...
@@ -221,7 +221,7 @@ __device__ unsigned char dQuantizeNormal(float x)
...
@@ -221,7 +221,7 @@ __device__ unsigned char dQuantizeNormal(float x)
if
(
x
>
0.1202552504837513
f
)
// 100
if
(
x
>
0.1202552504837513
f
)
// 100
return
0b1001
;
return
0b1001
;
else
else
return
0b1
1
00
;
return
0b1
0
00
;
else
else
if
(
x
>
-
0.33967943489551544
f
)
// 0
if
(
x
>
-
0.33967943489551544
f
)
// 0
if
(
x
>
-
0.13791173323988914
f
)
// 01
if
(
x
>
-
0.13791173323988914
f
)
// 01
...
@@ -726,8 +726,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
...
@@ -726,8 +726,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
#pragma unroll NUM_PER_TH
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
{
{
packed_4bit
|=
dQuantizeN
ormal
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeN
F4
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeN
ormal
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
packed_4bit
|=
dQuantizeN
F4
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
qvals
[
j
]
=
packed_4bit
;
qvals
[
j
]
=
packed_4bit
;
}
}
break
;
break
;
...
@@ -738,7 +738,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
...
@@ -738,7 +738,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
}
}
}
template
<
typename
T
,
int
TILE_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
FP4
>
template
<
typename
T
,
int
TILE_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
DATA_TYPE
>
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
)
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
)
{
{
...
@@ -747,55 +747,62 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
...
@@ -747,55 +747,62 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
int
valid_items_store
=
0
;
int
valid_items_store
=
0
;
const
int
base_idx
=
(
blockIdx
.
x
*
TILE_SIZE
);
const
int
base_idx
=
(
blockIdx
.
x
*
TILE_SIZE
);
T
vals
[
NUM_PER_TH
*
(
FP4
?
2
:
1
)];
T
vals
[
NUM_PER_TH
*
(
(
DATA_TYPE
>
0
)
?
2
:
1
)];
unsigned
char
qvals
[
NUM_PER_TH
];
unsigned
char
qvals
[
NUM_PER_TH
];
float
local_abs_max
=
-
FLT_MAX
;
float
local_abs_max
=
-
FLT_MAX
;
typedef
cub
::
BlockLoad
<
unsigned
char
,
THREADS
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadChar
;
typedef
cub
::
BlockLoad
<
unsigned
char
,
THREADS
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadChar
;
typedef
cub
::
BlockStore
<
T
,
THREADS
,
NUM_PER_TH
*
(
FP4
?
2
:
1
),
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreT
;
typedef
cub
::
BlockStore
<
T
,
THREADS
,
NUM_PER_TH
*
(
(
DATA_TYPE
>
0
)
?
2
:
1
),
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreT
;
__shared__
typename
LoadChar
::
TempStorage
loadchar
;
__shared__
typename
LoadChar
::
TempStorage
loadchar
;
__shared__
typename
StoreT
::
TempStorage
storet
;
__shared__
typename
StoreT
::
TempStorage
storet
;
for
(
unsigned
int
i
=
base_idx
;
i
<
n_load
;
i
+=
gridDim
.
x
*
TILE_SIZE
)
for
(
unsigned
int
i
=
base_idx
;
i
<
n_load
;
i
+=
gridDim
.
x
*
TILE_SIZE
)
{
{
if
(
FP4
)
if
(
DATA_TYPE
>
0
)
{
{
valid_items_load
=
(
n
+
1
)
/
2
-
i
>
TILE_SIZE
?
TILE_SIZE
:
(
n
+
1
)
/
2
-
i
;
valid_items_load
=
(
n
+
1
)
/
2
-
i
>
TILE_SIZE
?
TILE_SIZE
:
(
n
+
1
)
/
2
-
i
;
valid_items_store
=
n
-
i
*
2
>
TILE_SIZE
*
2
?
TILE_SIZE
*
2
:
n
-
i
*
2
;
valid_items_store
=
n
-
i
*
2
>
TILE_SIZE
*
2
?
TILE_SIZE
*
2
:
n
-
i
*
2
;
}
}
else
else
{
{
valid_items_load
=
n
-
i
>
TILE_SIZE
?
TILE_SIZE
:
n
-
i
;
valid_items_load
=
n
-
i
>
TILE_SIZE
?
TILE_SIZE
:
n
-
i
;
valid_items_store
=
n
-
i
>
TILE_SIZE
?
TILE_SIZE
:
n
-
i
;
valid_items_store
=
n
-
i
>
TILE_SIZE
?
TILE_SIZE
:
n
-
i
;
}
}
local_abs_max
=
__ldg
(
&
absmax
[(
i
+
threadIdx
.
x
*
NUM_PER_TH
)
/
(
blocksize
)]);
local_abs_max
=
__ldg
(
&
absmax
[(
i
+
threadIdx
.
x
*
NUM_PER_TH
)
/
(
blocksize
)]);
__syncthreads
();
__syncthreads
();
LoadChar
(
loadchar
).
Load
(
&
(
A
[
i
]),
qvals
,
valid_items_load
,
128
);
LoadChar
(
loadchar
).
Load
(
&
(
A
[
i
]),
qvals
,
valid_items_load
,
128
);
if
(
FP4
)
switch
(
DATA_TYPE
)
{
{
#pragma unroll NUM_PER_TH
case
General8bit
:
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
// load code through read-only cache via __ldg
{
#pragma unroll NUM_PER_TH
//vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
//vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
vals
[
j
]
=
__ldg
(
&
code
[
qvals
[
j
]])
*
local_abs_max
;
vals
[
j
*
2
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
>>
4
,
local_abs_max
);
break
;
vals
[
j
*
2
+
1
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
case
FP4
:
}
#pragma unroll NUM_PER_TH
}
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
else
{
{
vals
[
j
*
2
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
>>
4
,
local_abs_max
);
// load code through read-only cache via __ldg
vals
[
j
*
2
+
1
]
=
dDequantizeFP4Tree
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
#pragma unroll NUM_PER_TH
}
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
break
;
vals
[
j
]
=
__ldg
(
&
code
[
qvals
[
j
]])
*
local_abs_max
;
case
NF4
:
}
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
{
vals
[
j
*
2
]
=
dDequantizeNF4
(
qvals
[
j
]
>>
4
,
local_abs_max
);
vals
[
j
*
2
+
1
]
=
dDequantizeNF4
(
qvals
[
j
]
&
0x0F
,
local_abs_max
);
}
break
;
}
__syncthreads
();
__syncthreads
();
StoreT
(
storet
).
Store
(
&
(
out
[
FP4
?
i
*
2
:
i
]),
vals
,
valid_items_store
);
StoreT
(
storet
).
Store
(
&
(
out
[
(
DATA_TYPE
>
0
)
?
i
*
2
:
i
]),
vals
,
valid_items_store
);
}
}
}
}
...
...
tests/test_autograd.py
View file @
4ea489d3
...
@@ -440,7 +440,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
...
@@ -440,7 +440,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2
.
append
(
0
)
dim2
.
append
(
0
)
funcs
=
[(
torch
.
matmul
,
bnb
.
matmul_
fp4
)]
funcs
=
[(
torch
.
matmul
,
bnb
.
matmul_
4bit
)]
str_funcs
=
[
"matmul"
]
str_funcs
=
[
"matmul"
]
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad_str
=
[]
req_grad_str
=
[]
...
@@ -457,12 +457,13 @@ dtype = [torch.float16, torch.float32]
...
@@ -457,12 +457,13 @@ dtype = [torch.float16, torch.float32]
compress_statistics
=
[
False
,
True
]
compress_statistics
=
[
False
,
True
]
has_fp16_weights
=
[
True
,
False
]
has_fp16_weights
=
[
True
,
False
]
has_bias
=
[
True
,
False
]
has_bias
=
[
True
,
False
]
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
))
quant_type
=
[
'fp4'
,
'nf4'
]
str_values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
str_funcs
,
dtype
,
req_grad_str
,
str_transpose
,
has_bias
,
compress_statistics
))
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
))
names
=
[
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics"
.
format
(
*
vals
)
for
vals
in
str_values
]
str_values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
str_funcs
,
dtype
,
req_grad_str
,
str_transpose
,
has_bias
,
compress_statistics
,
quant_type
))
names
=
[
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}"
.
format
(
*
vals
)
for
vals
in
str_values
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics
, quant_type
"
,
values
,
ids
=
names
)
def
test_matmul_
fp4
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
):
def
test_matmul_
4bit
(
dim1
,
dim2
,
dim3
,
dim4
,
funcs
,
dtype
,
req_grad
,
transpose
,
has_bias
,
compress_statistics
,
quant_type
):
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimA
=
(
dim2
,
dim3
)
if
not
transpose
[
0
]
else
(
dim3
,
dim2
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
dimB
=
(
dim3
,
dim4
)
if
not
transpose
[
1
]
else
(
dim4
,
dim3
)
if
has_bias
==
False
:
if
has_bias
==
False
:
...
@@ -482,7 +483,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -482,7 +483,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias2
=
bias
.
clone
()
bias2
=
bias
.
clone
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_
fp4
(
B
,
compress_statistics
=
compress_statistics
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_
4bit
(
B
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
if
not
transpose
[
0
]
and
transpose
[
1
]:
if
not
transpose
[
0
]
and
transpose
[
1
]:
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
out_torch
=
funcs
[
0
](
A
,
B
.
t
())
...
...
tests/test_functional.py
View file @
4ea489d3
...
@@ -1784,8 +1784,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
...
@@ -1784,8 +1784,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
print
(
"partial matmul"
,
time
.
time
()
-
t0
)
print
(
"partial matmul"
,
time
.
time
()
-
t0
)
batch_size
=
4
batch_size
=
2
seqdim
=
2
56
seqdim
=
2
048
values
=
[]
values
=
[]
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
values
.
append
((
batch_size
,
seqdim
,
1024
,
4
*
1024
))
values
.
append
((
batch_size
,
seqdim
,
1024
,
4
*
1024
))
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
iters
=
128
iters
=
32
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
...
@@ -1808,6 +1808,8 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1808,6 +1808,8 @@ def test_bench_matmul(batch, seq, model, hidden):
B_fp4
,
state
=
F
.
quantize_fp4
(
B
)
B_fp4
,
state
=
F
.
quantize_fp4
(
B
)
B_fp4_c
,
state_c
=
F
.
quantize_fp4
(
B
,
compress_statistics
=
True
)
B_fp4_c
,
state_c
=
F
.
quantize_fp4
(
B
,
compress_statistics
=
True
)
B_nf4
,
state_nf4
=
F
.
quantize_nf4
(
B
)
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit
.
eval
()
linear8bit
.
eval
()
...
@@ -1836,17 +1838,24 @@ def test_bench_matmul(batch, seq, model, hidden):
...
@@ -1836,17 +1838,24 @@ def test_bench_matmul(batch, seq, model, hidden):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_
fp4
(
A
,
B_fp4
.
t
(),
quant_state
=
state
)
bnb
.
matmul_
4bit
(
A
,
B_fp4
.
t
(),
quant_state
=
state
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb fp4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb fp4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
bnb
.
matmul_
fp4
(
A
,
B_fp4
.
t
(),
quant_state
=
state_c
)
bnb
.
matmul_
4bit
(
A
,
B_fp4
.
t
(),
quant_state
=
state_c
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
f
"bnb fp4 + compressed stats: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
print
(
f
"bnb fp4 + compressed stats: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
bnb
.
matmul_4bit
(
A
,
B_nf4
.
t
(),
quant_state
=
state_nf4
)
torch
.
cuda
.
synchronize
()
print
(
f
"bnb nf4: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s"
)
#torch.cuda.synchronize()
#torch.cuda.synchronize()
#t0 = time.time()
#t0 = time.time()
#for i in range(iters):
#for i in range(iters):
...
@@ -2262,17 +2271,18 @@ def test_4bit_compressed_stats(quant_type):
...
@@ -2262,17 +2271,18 @@ def test_4bit_compressed_stats(quant_type):
errs2
=
[]
errs2
=
[]
for
i
in
range
(
10
):
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
q2
,
SA2
=
F
.
quantize_4bit
_packed
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q2
,
SA2
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
_packed
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
_packed
(
q2
,
SA2
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
(
q2
,
SA2
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
_packed
(
q3
,
SA3
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
(
q3
,
SA3
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs1
.
append
(
relerr
.
item
())
errs1
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
...
@@ -2281,23 +2291,23 @@ def test_4bit_compressed_stats(quant_type):
...
@@ -2281,23 +2291,23 @@ def test_4bit_compressed_stats(quant_type):
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
err
=
err
.
mean
()
errs2
.
append
(
rel
err
.
item
())
errs2
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
assert
relerr
.
item
()
<
0.28
#print(sum(errs1)/len(errs1), blocksize)
#print(sum(errs1)/len(errs1), blocksize
, quant_type
)
#print(sum(errs2)/len(errs2), blocksize)
#print(sum(errs2)/len(errs2), blocksize
, quant_type
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
def
test_bench_
fp4
_dequant
(
quant_type
):
def
test_bench_
4bit
_dequant
(
quant_type
):
blocksize
=
256
blocksize
=
256
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'cuda'
).
half
()
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'cuda'
).
half
()
qa
,
SA
=
F
.
quantize_4bit
_packed
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
qa
,
SA
=
F
.
quantize_4bit
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
input_size
=
a
.
numel
()
/
2
input_size
=
a
.
numel
()
/
2
output_size
=
a
.
numel
()
*
2
output_size
=
a
.
numel
()
*
2
...
@@ -2311,7 +2321,7 @@ def test_bench_fp4_dequant(quant_type):
...
@@ -2311,7 +2321,7 @@ def test_bench_fp4_dequant(quant_type):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
F
.
dequantize_4bit
_packed
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
F
.
dequantize_4bit
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
#b.copy_(a)
#b.copy_(a)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/iters*1e6)
#print((time.time()-t0)/iters*1e6)
...
...
tests/test_modules.py
View file @
4ea489d3
...
@@ -506,8 +506,16 @@ def test_linear_kbit_fp32_bias(module):
...
@@ -506,8 +506,16 @@ def test_linear_kbit_fp32_bias(module):
o1
=
l1
(
b1
)
o1
=
l1
(
b1
)
assert
l1
.
bias
is
None
assert
l1
.
bias
is
None
modules
=
[]
modules
.
append
(
bnb
.
nn
.
Linear8bitLt
)
modules
.
append
(
bnb
.
nn
.
Linear4bit
)
modules
.
append
(
bnb
.
nn
.
LinearFP4
)
modules
.
append
(
bnb
.
nn
.
LinearNF4
)
modules
.
append
(
lambda
d1
,
d2
:
bnb
.
nn
.
LinearFP4
(
d1
,
d2
,
compress_statistics
=
True
))
modules
.
append
(
lambda
d1
,
d2
:
bnb
.
nn
.
LinearNF4
(
d1
,
d2
,
compress_statistics
=
True
))
names
=
[
'Int8Lt'
,
'4bit'
,
'FP4'
,
'NF4'
,
'FP4+C'
,
'NF4+C'
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"module"
,
[
bnb
.
nn
.
Linear8bitLt
,
bnb
.
nn
.
LinearFP4
,
lambda
d1
,
d2
:
bnb
.
nn
.
LinearFP4
(
d1
,
d2
,
compress_statistics
=
True
)],
ids
=
[
'Int8Lt'
,
'FP4'
,
'FP4+C'
]
)
@
pytest
.
mark
.
parametrize
(
"module"
,
modules
,
ids
=
names
)
def
test_kbit_backprop
(
module
):
def
test_kbit_backprop
(
module
):
b
=
17
b
=
17
dim1
=
37
dim1
=
37
...
@@ -515,6 +523,8 @@ def test_kbit_backprop(module):
...
@@ -515,6 +523,8 @@ def test_kbit_backprop(module):
ref
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
torch
.
nn
.
Linear
(
dim2
,
10
)])
ref
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
torch
.
nn
.
Linear
(
dim2
,
10
)])
ref
[
1
].
weight
.
requires_grad
=
False
ref
[
1
].
weight
.
requires_grad
=
False
torch
.
nn
.
init
.
kaiming_normal_
(
ref
[
0
].
weight
)
torch
.
nn
.
init
.
kaiming_normal_
(
ref
[
1
].
weight
)
kbit
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
module
(
dim2
,
10
)])
kbit
=
nn
.
Sequential
(
*
[
torch
.
nn
.
Linear
(
dim1
,
dim2
),
module
(
dim2
,
10
)])
kbit
[
0
].
weight
.
detach
().
copy_
(
ref
[
0
].
weight
)
kbit
[
0
].
weight
.
detach
().
copy_
(
ref
[
0
].
weight
)
kbit
[
1
].
weight
.
detach
().
copy_
(
ref
[
1
].
weight
)
kbit
[
1
].
weight
.
detach
().
copy_
(
ref
[
1
].
weight
)
...
@@ -523,6 +533,10 @@ def test_kbit_backprop(module):
...
@@ -523,6 +533,10 @@ def test_kbit_backprop(module):
ref
=
ref
.
half
().
cuda
()
ref
=
ref
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
kbit
=
kbit
.
half
().
cuda
()
errs1
=
[]
errs2
=
[]
relerrs1
=
[]
relerrs2
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
batch
=
torch
.
randn
(
b
,
dim1
).
half
().
cuda
()
batch
=
torch
.
randn
(
b
,
dim1
).
half
().
cuda
()
out1
=
ref
(
batch
)
out1
=
ref
(
batch
)
...
@@ -535,12 +549,26 @@ def test_kbit_backprop(module):
...
@@ -535,12 +549,26 @@ def test_kbit_backprop(module):
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad1
=
ref
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
bgrad2
=
kbit
[
0
].
bias
.
grad
torch
.
testing
.
assert_allclose
(
grad1
,
grad2
,
atol
=
0.008
,
rtol
=
0.05
)
err1
=
(
out1
-
out2
).
abs
().
float
()
torch
.
testing
.
assert_allclose
(
bgrad1
,
bgrad2
,
atol
=
0.008
,
rtol
=
0.05
)
err2
=
(
grad1
-
grad2
).
abs
().
float
()
relerr1
=
(
err1
/
(
out1
.
abs
().
float
()
+
1e-9
))
relerr2
=
(
err2
/
(
grad1
.
abs
().
float
()
+
1e-9
))
errs1
.
append
(
err1
.
mean
().
item
())
errs2
.
append
(
err2
.
mean
().
item
())
relerrs1
.
append
(
relerr1
.
mean
().
item
())
relerrs2
.
append
(
relerr2
.
mean
().
item
())
#torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
#torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
ref
.
zero_grad
()
ref
.
zero_grad
()
kbit
.
zero_grad
()
kbit
.
zero_grad
()
assert
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
print
(
'out'
,
sum
(
errs1
)
/
len
(
errs1
))
print
(
'grad'
,
sum
(
errs2
)
/
len
(
errs2
))
print
(
'rel out'
,
sum
(
relerrs1
)
/
len
(
relerrs1
))
print
(
'rel grad'
,
sum
(
relerrs2
)
/
len
(
relerrs2
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment