Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
8f88cef1
Unverified
Commit
8f88cef1
authored
Apr 10, 2025
by
Matthew Douglas
Committed by
GitHub
Apr 10, 2025
Browse files
Fix #1588 - torch compatibility for <=2.4 (#1590)
parent
d2fe0e3c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
17 deletions
+22
-17
bitsandbytes/_ops.py
bitsandbytes/_ops.py
+7
-7
bitsandbytes/backends/cpu/ops.py
bitsandbytes/backends/cpu/ops.py
+2
-2
bitsandbytes/backends/cuda/ops.py
bitsandbytes/backends/cuda/ops.py
+2
-2
bitsandbytes/backends/default/ops.py
bitsandbytes/backends/default/ops.py
+11
-6
No files found.
bitsandbytes/_ops.py
View file @
8f88cef1
...
@@ -19,7 +19,7 @@ else:
...
@@ -19,7 +19,7 @@ else:
# Higher level op: int8 matmul + dequant + bias
# Higher level op: int8 matmul + dequant + bias
torch
.
library
.
define
(
torch
.
library
.
define
(
"bitsandbytes::int8_scaled_mm"
,
"bitsandbytes::int8_scaled_mm"
,
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=
float16
) -> Tensor"
,
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType
?
dtype=
None
) -> Tensor"
,
)
)
...
@@ -30,10 +30,10 @@ def _(
...
@@ -30,10 +30,10 @@ def _(
row_stats
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
dtype
=
torch
.
float16
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
shapeC
=
(
*
A
.
shape
[:
-
1
],
B
.
shape
[
0
])
shapeC
=
(
*
A
.
shape
[:
-
1
],
B
.
shape
[
0
])
return
torch
.
empty
(
shapeC
,
device
=
A
.
device
,
dtype
=
dtype
)
return
torch
.
empty
(
shapeC
,
device
=
A
.
device
,
dtype
=
dtype
or
torch
.
float16
)
torch
.
library
.
define
(
torch
.
library
.
define
(
...
@@ -98,7 +98,7 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
...
@@ -98,7 +98,7 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
# Default PyTorch-native implementation
# Default PyTorch-native implementation
@
register_kernel
(
"bitsandbytes::int8_vectorwise_dequant"
,
None
)
@
register_kernel
(
"bitsandbytes::int8_vectorwise_dequant"
,
"default"
)
def
_
(
A
:
torch
.
Tensor
,
stats
:
torch
.
Tensor
):
def
_
(
A
:
torch
.
Tensor
,
stats
:
torch
.
Tensor
):
# To dequantize we divide by 127, or multiply by the reciprocal.
# To dequantize we divide by 127, or multiply by the reciprocal.
return
A
*
stats
.
view
(
-
1
,
1
)
*
7.874015718698502e-3
return
A
*
stats
.
view
(
-
1
,
1
)
*
7.874015718698502e-3
...
@@ -106,7 +106,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
...
@@ -106,7 +106,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
torch
.
library
.
define
(
torch
.
library
.
define
(
"bitsandbytes::int8_mm_dequant"
,
"bitsandbytes::int8_mm_dequant"
,
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=
float16
, Tensor? bias=None) -> Tensor"
,
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType
?
dtype=
None
, Tensor? bias=None) -> Tensor"
,
)
)
...
@@ -115,11 +115,11 @@ def _(
...
@@ -115,11 +115,11 @@ def _(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
dtype
=
torch
.
float16
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
"A must be int32"
)
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
"A must be int32"
)
return
torch
.
empty_like
(
A
,
dtype
=
dtype
)
return
torch
.
empty_like
(
A
,
dtype
=
dtype
or
torch
.
float16
)
torch
.
library
.
define
(
torch
.
library
.
define
(
...
...
bitsandbytes/backends/cpu/ops.py
View file @
8f88cef1
...
@@ -28,7 +28,7 @@ def _(
...
@@ -28,7 +28,7 @@ def _(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
dtype
=
torch
.
float16
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
f
"A must be int32, got
{
A
.
dtype
}
"
)
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
f
"A must be int32, got
{
A
.
dtype
}
"
)
...
@@ -43,7 +43,7 @@ def _(
...
@@ -43,7 +43,7 @@ def _(
if
bias
is
not
None
:
if
bias
is
not
None
:
out
+=
bias
out
+=
bias
return
out
.
to
(
dtype
)
return
out
.
to
(
dtype
or
torch
.
float16
)
@
register_kernel
(
"bitsandbytes::quantize_blockwise"
,
"cpu"
)
@
register_kernel
(
"bitsandbytes::quantize_blockwise"
,
"cpu"
)
...
...
bitsandbytes/backends/cuda/ops.py
View file @
8f88cef1
...
@@ -90,7 +90,7 @@ def _(
...
@@ -90,7 +90,7 @@ def _(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
dtype
=
torch
.
float16
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
f
"A must be int32, got
{
A
.
dtype
}
"
)
torch
.
_check
(
A
.
dtype
==
torch
.
int32
,
lambda
:
f
"A must be int32, got
{
A
.
dtype
}
"
)
...
@@ -121,7 +121,7 @@ def _(
...
@@ -121,7 +121,7 @@ def _(
if
bias
is
not
None
and
bias
.
dtype
!=
torch
.
float16
:
if
bias
is
not
None
and
bias
.
dtype
!=
torch
.
float16
:
out
.
add_
(
bias
)
out
.
add_
(
bias
)
return
out
.
to
(
dtype
)
return
out
.
to
(
dtype
or
torch
.
float16
)
@
register_kernel
(
"bitsandbytes::int8_vectorwise_quant"
,
"cuda"
)
@
register_kernel
(
"bitsandbytes::int8_vectorwise_quant"
,
"cuda"
)
...
...
bitsandbytes/backends/default/ops.py
View file @
8f88cef1
...
@@ -5,26 +5,31 @@ import torch
...
@@ -5,26 +5,31 @@ import torch
from
..._ops
import
register_kernel
from
..._ops
import
register_kernel
@
register_kernel
(
"bitsandbytes::int8_scaled_mm"
,
None
)
@
register_kernel
(
"bitsandbytes::int8_scaled_mm"
,
"default"
)
def
_
(
def
_
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
row_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
col_stats
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
dtype
=
torch
.
float16
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
out_i32
=
torch
.
ops
.
bitsandbytes
.
int8_linear_matmul
.
default
(
A
,
B
)
out_i32
=
torch
.
ops
.
bitsandbytes
.
int8_linear_matmul
.
default
(
A
,
B
)
out
=
torch
.
ops
.
bitsandbytes
.
int8_mm_dequant
.
default
(
out_i32
,
row_stats
,
col_stats
,
dtype
=
dtype
,
bias
=
bias
)
return
torch
.
ops
.
bitsandbytes
.
int8_mm_dequant
.
default
(
return
out
out_i32
,
row_stats
,
col_stats
,
dtype
=
dtype
or
torch
.
float16
,
bias
=
bias
,
)
@
register_kernel
(
"bitsandbytes::int8_linear_matmul"
,
None
)
@
register_kernel
(
"bitsandbytes::int8_linear_matmul"
,
"default"
)
def
_
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
):
def
_
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
):
return
_int8_linear_matmul_impl
(
A
,
B
)
return
_int8_linear_matmul_impl
(
A
,
B
)
@
register_kernel
(
"bitsandbytes::int8_linear_matmul.out"
,
None
)
@
register_kernel
(
"bitsandbytes::int8_linear_matmul.out"
,
"default"
)
def
_
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
def
_
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
torch
.
_check
(
out
.
dtype
==
torch
.
int32
)
torch
.
_check
(
out
.
dtype
==
torch
.
int32
)
_int8_linear_matmul_impl
(
A
,
B
,
out
)
_int8_linear_matmul_impl
(
A
,
B
,
out
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment