Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
97073cdb
Unverified
Commit
97073cdb
authored
Apr 17, 2025
by
Matthew Douglas
Committed by
GitHub
Apr 17, 2025
Browse files
Support LLM.int8() inference with torch.compile (#1594)
parent
feaedbb0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
25 deletions
+84
-25
bitsandbytes/_ops.py
bitsandbytes/_ops.py
+26
-0
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+16
-25
bitsandbytes/backends/cuda/ops.py
bitsandbytes/backends/cuda/ops.py
+42
-0
No files found.
bitsandbytes/_ops.py
View file @
97073cdb
...
...
@@ -15,6 +15,32 @@ else:
register_fake
=
torch
.
library
.
impl_abstract
register_kernel
=
torch
.
library
.
impl
# Int8 mixed precision matmul + dequant + bias
# Declares the custom op schema so torch.compile can trace through it.
torch.library.define(
    "bitsandbytes::int8_mixed_scaled_mm",
    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_mixed_scaled_mm")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fake (meta) implementation for shape/dtype tracing under torch.compile.

    Runs no kernel; only allocates correctly-shaped placeholder outputs so
    the compiler can propagate shapes through the op.
    """
    # Output keeps all leading dims of the quantized input CA and takes the
    # leading (output-feature) dim of the quantized weight CB.
    shapeC = (*CA.shape[:-1], CB.shape[0])
    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)
    # The count of outlier columns is data-dependent, so expose it to the
    # tracer as a fresh dynamic size rather than a fixed number.
    outlier_cols = torch.library.get_ctx().new_dynamic_size()
    # NOTE(review): placeholder subA is created with dtype=torch.int64 while
    # the eager CUDA kernel returns subA in A.dtype — presumably only the
    # symbolic length matters for tracing; confirm against the CUDA kernel.
    subA = A.new_empty(outlier_cols, dtype=torch.int64)
    return out, subA
# Higher level op: int8 matmul + dequant + bias
torch
.
library
.
define
(
...
...
bitsandbytes/autograd/_functions.py
View file @
97073cdb
...
...
@@ -210,37 +210,28 @@ class MatMul8bitLt(torch.autograd.Function):
# 2. Quantize B
state
.
CB
,
state
.
SCB
,
_
=
F
.
int8_vectorwise_quant
(
B
.
to
(
torch
.
float16
))
# Handle sparse decomposition. In some instances, we may have not found any
# outlier columns at all. In that case, we'll skip this part completely.
if
state
.
threshold
>
0.0
and
outlier_cols
is
not
None
and
outlier_cols
.
numel
():
# Handle sparse decomposition
if
state
.
threshold
>
0.0
:
state
.
idx
=
outlier_cols
# Zero out the outliers in the transposed 8bit inputs.
if
CAt
is
not
None
:
CAt
[:,
state
.
idx
]
=
0
# Extract the input outliers in original precision
subA
=
A
[:,
state
.
idx
].
contiguous
()
# Mixed Int8 Matmul + Dequant + Bias
output
,
subA
=
torch
.
ops
.
bitsandbytes
.
int8_mixed_scaled_mm
(
A
,
CA
,
state
.
CB
,
SCA
,
state
.
SCB
,
outlier_cols
,
bias
,
)
# Extract the corresponding weights
if
state
.
has_fp16_weights
:
state
.
subB
=
B
[:,
state
.
idx
].
t
()
else
:
# To dequantize our weights associated with the input outliers,
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers
=
state
.
CB
[:,
state
.
idx
]
state
.
subB
=
F
.
int8_vectorwise_dequant
(
outliers
,
state
.
SCB
).
to
(
A
.
dtype
).
t
()
else
:
# Int8 Matmul + Dequant + Bias
output
=
torch
.
ops
.
bitsandbytes
.
int8_scaled_mm
.
default
(
CA
,
state
.
CB
,
SCA
,
state
.
SCB
,
bias
=
bias
,
dtype
=
A
.
dtype
)
subA
=
None
# 3. Int8 Matmul + Dequant + Bias
output
=
torch
.
ops
.
bitsandbytes
.
int8_scaled_mm
.
default
(
CA
,
state
.
CB
,
SCA
,
state
.
SCB
,
bias
=
bias
,
dtype
=
A
.
dtype
)
# 4. Mixed-precision decomposition matmul
if
subA
is
not
None
and
state
.
subB
is
not
None
:
output
=
output
.
addmm
(
subA
,
state
.
subB
)
# 5. Save state
ctx
.
state
=
state
...
...
bitsandbytes/backends/cuda/ops.py
View file @
97073cdb
...
...
@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl
(
A
,
B
,
out
)
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """CUDA implementation of the mixed-precision LLM.int8() matmul.

    Performs the int8 matmul with dequant and optional bias, and — when
    outlier columns are present — adds back a higher-precision matmul over
    just those columns (sparse decomposition). Returns the result together
    with the extracted outlier inputs ``subA``.
    """
    subB = None

    if outlier_cols is not None and outlier_cols.numel():
        # Pull the outlier input columns in their original precision.
        subA = A[:, outlier_cols].contiguous()

        # Dequantize only the matching weight columns, cast to A's dtype,
        # and transpose so they are ready for addmm.
        deq = torch.ops.bitsandbytes.int8_vectorwise_dequant.default(
            CB[:, outlier_cols].contiguous(), SCB
        )
        subB = deq.to(A.dtype).t()

        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
    else:
        # torch.compile needs a tensor result even when no outliers exist,
        # so hand back an empty placeholder instead of None.
        subA = torch.empty(0, device=A.device, dtype=A.dtype)

    # Main path: int8 matmul + dequant + optional bias.
    output = torch.ops.bitsandbytes.int8_scaled_mm.default(
        CA, CB, SCA, SCB, bias=bias, dtype=A.dtype
    )

    # Fold the decomposed outlier contribution back into the result.
    if subB is not None:
        output = output.addmm(subA, subB)

    return output, subA
def
_int8_linear_matmul_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
A
,
B
=
B
,
A
...
...
@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):
if
outliers
.
any
():
outlier_cols
=
torch
.
argwhere
(
outliers
.
any
(
dim
=
0
)).
view
(
-
1
)
else
:
# Needed for torch.compile support.
outlier_cols
=
torch
.
empty
(
0
,
device
=
A
.
device
,
dtype
=
torch
.
int64
)
with
_cuda_device_of
(
A
):
lib
.
cint8_vector_quant
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment