OpenDAS / bitsandbytes · Commits

Commit 4c11d6dc
authored Sep 20, 2023 by Ruslan Svirschevski
reverted fn signatures in functional()
parent 1d9f0f2a
Showing 3 changed files with 15 additions and 15 deletions
bitsandbytes/autograd/_functions.py  +1 -1
bitsandbytes/functional.py  +13 -13
tests/test_functional.py  +1 -1
bitsandbytes/autograd/_functions.py
@@ -569,7 +569,7 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = N
             warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, quant_state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
                 out += bias
             return out
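The line changed above sits on the fast inference path of matmul_4bit: when A is a single row, the call dispatches to the fused gemv kernel instead of MatMul4Bit, and after this commit the quantization state is forwarded with the keyword `state=` again. A minimal sketch of that path from the caller's side, assuming a CUDA device and a 4-bit-capable build of bitsandbytes; shapes are illustrative, not taken from the commit:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')   # dense weight
    qW, quant_state = F.quantize_4bit(W, quant_type='nf4')            # packed 4-bit weight + its quant state
    A = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')   # single-token activation

    # For a 1-row A, matmul_4bit takes the gemv fast path edited above,
    # which now forwards the state as F.gemv_4bit(..., state=quant_state).
    out = bnb.matmul_4bit(A, qW.t(), quant_state)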
bitsandbytes/functional.py
@@ -1579,22 +1579,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    quant_state=None
+    state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if quant_state is None:
+    if state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')

     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = quant_state.shape
+    Bshape = state.shape
     bout = Bshape[0]
-    absmax = quant_state.absmax
-    if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-        absmax += quant_state.offset
+    absmax = state.absmax
+    if state.nested:
+        absmax = dequantize_blockwise(state.absmax, state.state2)
+        absmax += state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1608,7 +1608,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, quant_state.code])
+    is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1618,11 +1618,11 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
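The gemv_4bit hunks above restore the parameter name `state` in the signature, the function body, and the kernel dispatch, which selects cgemm_4bit_inference_naive_fp16 / _bf16 / _fp32 from A.dtype. A rough sketch of a direct call with the reverted keyword, mirroring the test change further down; shapes and quant_type are assumptions:

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qW, state = F.quantize_4bit(W, quant_type='nf4', compress_statistics=True)

    # A must be a single row (A.numel() == A.shape[-1]); otherwise gemv_4bit raises ValueError.
    A = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
    out = F.gemv_4bit(A, qW.t(), state=state)   # keyword is `state` again, not `quant_state`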
@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
 def mm_dequant(
     A,
-    state,
+    quant_state,
     row_stats,
     col_stats,
     out=None,
@@ -1914,7 +1914,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = state[0]
+    out_shape = quant_state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
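For mm_dequant the rename goes the other way: the second positional argument is called quant_state again, and it is the output state returned by igemmlt (its first element is the int32 result's shape), not a 4-bit QuantState object. A rough sketch of the int8 matmul path it belongs to, with illustrative names and shapes; 'col_turing' assumes a Turing GPU ('col_ampere' on Ampere):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(32, 4096, dtype=torch.float16, device='cuda')     # activations
    W = torch.randn(1024, 4096, dtype=torch.float16, device='cuda')   # weight, [out_features, in_features]

    CA, CAt, SCA, SCAt, _ = F.double_quant(A)    # int8 rows of A + per-row absmax SCA
    CW, CWt, SCW, SCWt, _ = F.double_quant(W)    # int8 rows of W + per-row absmax SCW

    C32A, SA = F.transform(CA, 'col32')
    CxW, SW = F.transform(CW, 'col_turing')

    out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)
    # Sout32 is the argument this commit renames back to `quant_state`
    out_fp16 = F.mm_dequant(out32, Sout32, SCA, SCW)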
tests/test_functional.py
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
         qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
         C3 = torch.matmul(A, B.t())
-        C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
+        C2 = F.gemv_4bit(A, qB.t(), state=state)
         A.requires_grad = True
         C1 = bnb.matmul_4bit(A, qB.t(), state)