Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
b8d1c261
Unverified
Commit
b8d1c261
authored
Sep 29, 2025
by
Matthew Douglas
Committed by
GitHub
Sep 29, 2025
Browse files
Linear8bitLt: support device movement after forward() (#1769)
parent
42e8abc3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
7 deletions
+68
-7
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+30
-7
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+38
-0
No files found.
bitsandbytes/nn/modules.py
View file @
b8d1c261
...
@@ -679,19 +679,27 @@ class Int8Params(torch.nn.Parameter):
...
@@ -679,19 +679,27 @@ class Int8Params(torch.nn.Parameter):
def to(self, *args, **kwargs):
    """Move/cast this parameter, quantizing on first transfer off the CPU.

    Mirrors ``torch.nn.Parameter.to`` semantics. If the data has not yet
    been quantized (dtype is not ``torch.int8``) and is being moved from
    CPU to a real (non-meta) device, quantization is performed via
    ``self._quantize``. Otherwise a new ``Int8Params`` is created on the
    target device and, when the data was already quantized, the
    quantization statistics (``CB``/``SCB``) are carried along so the
    parameter keeps working after repeated device moves.

    Returns:
        Int8Params: the parameter on the target device (possibly newly
        quantized).
    """
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

    is_quantized = self.data.dtype == torch.int8

    if not is_quantized and device is not None and device.type != "meta" and self.data.device.type == "cpu":
        # We're moving from a CPU device to a non-meta device.
        # In this circumstance, we want to quantize if we haven't already.
        return self._quantize(device)

    # Create a new parameter on the target device.
    new_param = Int8Params(
        super().to(device=device, dtype=dtype, non_blocking=non_blocking),
        requires_grad=self.requires_grad,
        has_fp16_weights=self.has_fp16_weights,
    )

    # If we had already quantized, move the statistics appropriately.
    if is_quantized and device is not None:
        if self.CB is not None:
            # The moved int8 data *is* the quantized weight; reuse it as CB
            # rather than moving the stale CPU copy.
            new_param.CB = new_param.data
        if self.SCB is not None:
            new_param.SCB = self.SCB.to(device)

    return new_param
...
@@ -1037,6 +1045,21 @@ class Linear8bitLt(nn.Linear):
...
@@ -1037,6 +1045,21 @@ class Linear8bitLt(nn.Linear):
self
.
weight
.
CB
=
None
self
.
weight
.
CB
=
None
self
.
weight
.
SCB
=
None
self
.
weight
.
SCB
=
None
def to(self, *args, **kwargs):
    """Move the layer, including the int8 matmul state tensors.

    ``nn.Module.to`` only relocates registered parameters and buffers.
    The quantization state captured during a prior ``forward()``
    (``state.CB``/``state.SCB``) lives outside that machinery, so it is
    moved explicitly here; otherwise a layer moved after its first
    forward pass would mix devices.

    Returns:
        The module on the target device (same object ``super().to``
        returns).
    """
    # Call the parent to() method to handle standard parameter/buffer movement
    result = super().to(*args, **kwargs)

    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

    # Handle state tensors if needed.
    if device is not None:
        if result.state.CB is not None:
            result.state.CB = result.state.CB.to(device)
        if result.state.SCB is not None:
            result.state.SCB = result.state.SCB.to(device)

    return result
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
self
.
state
.
is_training
=
self
.
training
self
.
state
.
is_training
=
self
.
training
if
self
.
weight
.
CB
is
not
None
:
if
self
.
weight
.
CB
is
not
None
:
...
...
tests/test_linear8bitlt.py
View file @
b8d1c261
...
@@ -293,3 +293,41 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
...
@@ -293,3 +293,41 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
grad_compiled
=
x
.
grad
.
clone
()
grad_compiled
=
x
.
grad
.
clone
()
torch
.
testing
.
assert_close
(
grad_compiled
,
grad_ref
)
torch
.
testing
.
assert_close
(
grad_compiled
,
grad_ref
)
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
def test_linear8bitlt_device_movement(device):
    """Test moving a Linear8bitLt layer between CPU and an accelerator device."""
    # Create a Linear8bitLt layer on CPU
    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
    torch.nn.init.xavier_uniform_(layer.weight)

    # Create a sample input.
    x = torch.randn(4, 32, dtype=torch.float16, device="cpu")

    # Move to the device. This should quantize the weights.
    layer = layer.to(device)
    assert layer.weight.data.dtype == torch.int8

    # Call the layer on the accelerator device.
    out_accelerator = layer(x.to(device))

    # Move back to CPU and call again.
    layer = layer.to("cpu")
    out_cpu = layer(x)

    # Move back to the accelerator device and call again.
    layer = layer.to(device)
    out_accelerator_2 = layer(x.to(device))

    # Move back to the CPU and call one last time.
    layer = layer.to("cpu")
    out_cpu_2 = layer(x)

    # Round-trips must be lossless: the quantized weights/statistics are
    # moved, not re-derived, so outputs should be bitwise-stable.
    # CPU outputs should match both times.
    torch.testing.assert_close(out_cpu_2, out_cpu, rtol=1e-8, atol=1e-8)

    # Accelerator outputs should match both times.
    torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment