Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4c11d6dc
Commit
4c11d6dc
authored
Sep 20, 2023
by
Ruslan Svirschevski
Browse files
reverted fn signatures in functional()
parent
1d9f0f2a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
15 deletions
+15
-15
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+1
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+13
-13
tests/test_functional.py
tests/test_functional.py
+1
-1
No files found.
bitsandbytes/autograd/_functions.py
View file @
4c11d6dc
...
...
@@ -569,7 +569,7 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = N
warn
(
f
'Some matrices hidden dimension is not a multiple of
{
quant_state
.
blocksize
}
and efficient inference kernels are not supported for these (slow). Matrix input size found:
{
A
.
shape
}
'
)
return
MatMul4Bit
.
apply
(
A
,
B
,
out
,
bias
,
quant_state
)
else
:
out
=
F
.
gemv_4bit
(
A
,
B
.
t
(),
out
,
quant_
state
=
quant_state
)
out
=
F
.
gemv_4bit
(
A
,
B
.
t
(),
out
,
state
=
quant_state
)
if
bias
is
not
None
:
out
+=
bias
return
out
...
...
bitsandbytes/functional.py
View file @
4c11d6dc
...
...
@@ -1579,22 +1579,22 @@ def gemv_4bit(
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
,
quant_
state
=
None
state
=
None
):
prev_device
=
pre_call
(
A
.
device
)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if
quant_
state
is
None
:
if
state
is
None
:
raise
ValueError
(
f
'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )'
)
if
A
.
numel
()
!=
A
.
shape
[
-
1
]:
raise
ValueError
(
f
'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]'
)
Bshape
=
quant_
state
.
shape
Bshape
=
state
.
shape
bout
=
Bshape
[
0
]
absmax
=
quant_
state
.
absmax
if
quant_
state
.
nested
:
absmax
=
dequantize_blockwise
(
quant_
state
.
absmax
,
quant_
state
.
state2
)
absmax
+=
quant_
state
.
offset
absmax
=
state
.
absmax
if
state
.
nested
:
absmax
=
dequantize_blockwise
(
state
.
absmax
,
state
.
state2
)
absmax
+=
state
.
offset
if
out
is
None
:
if
len
(
A
.
shape
)
==
3
:
...
...
@@ -1608,7 +1608,7 @@ def gemv_4bit(
lda
=
Bshape
[
0
]
ldc
=
Bshape
[
0
]
ldb
=
(
A
.
shape
[
-
1
]
+
1
)
//
2
is_on_gpu
([
B
,
A
,
out
,
absmax
,
quant_
state
.
code
])
is_on_gpu
([
B
,
A
,
out
,
absmax
,
state
.
code
])
m
=
ct
.
c_int32
(
m
)
n
=
ct
.
c_int32
(
n
)
k
=
ct
.
c_int32
(
k
)
...
...
@@ -1618,11 +1618,11 @@ def gemv_4bit(
if
B
.
dtype
==
torch
.
uint8
:
if
A
.
dtype
==
torch
.
float16
:
lib
.
cgemm_4bit_inference_naive_fp16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
quant_
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
quant_
state
.
blocksize
))
lib
.
cgemm_4bit_inference_naive_fp16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
.
blocksize
))
elif
A
.
dtype
==
torch
.
bfloat16
:
lib
.
cgemm_4bit_inference_naive_bf16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
quant_
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
quant_
state
.
blocksize
))
lib
.
cgemm_4bit_inference_naive_bf16
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
.
blocksize
))
elif
A
.
dtype
==
torch
.
float32
:
lib
.
cgemm_4bit_inference_naive_fp32
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
quant_
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
quant_
state
.
blocksize
))
lib
.
cgemm_4bit_inference_naive_fp32
(
m
,
n
,
k
,
get_ptr
(
A
),
get_ptr
(
B
),
get_ptr
(
absmax
),
get_ptr
(
state
.
code
),
get_ptr
(
out
),
lda
,
ldb
,
ldc
,
ct
.
c_int32
(
state
.
blocksize
))
else
:
raise
NotImplementedError
(
f
'Matmul not implemented for data type
{
A
.
dtype
}
'
)
...
...
@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
def
mm_dequant
(
A
,
state
,
quant_
state
,
row_stats
,
col_stats
,
out
=
None
,
...
...
@@ -1914,7 +1914,7 @@ def mm_dequant(
):
assert
A
.
dtype
==
torch
.
int32
if
bias
is
not
None
:
assert
bias
.
dtype
==
torch
.
float16
out_shape
=
state
[
0
]
out_shape
=
quant_
state
[
0
]
if
len
(
out_shape
)
==
3
:
out_shape
=
(
out_shape
[
0
]
*
out_shape
[
1
],
out_shape
[
2
])
...
...
tests/test_functional.py
View file @
4c11d6dc
...
...
@@ -2401,7 +2401,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
qB
,
state
=
F
.
quantize_4bit
(
B
,
quant_type
=
storage_type
,
compress_statistics
=
double_quant
)
C3
=
torch
.
matmul
(
A
,
B
.
t
())
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
quant_
state
=
state
)
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
state
=
state
)
A
.
requires_grad
=
True
C1
=
bnb
.
matmul_4bit
(
A
,
qB
.
t
(),
state
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment