Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
42653921
Unverified
Commit
42653921
authored
Aug 06, 2025
by
Matthew Douglas
Committed by
GitHub
Aug 06, 2025
Browse files
Merge pull request #1719 from ved1beta/fsdp_integration2
Fix Params4bit tensor subclass handling
parents
e54dc125
0ecb8fb4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
0 deletions
+75
-0
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+40
-0
tests/test_linear4bit.py
tests/test_linear4bit.py
+35
-0
No files found.
bitsandbytes/nn/modules.py
View file @
42653921
...
@@ -356,6 +356,46 @@ class Params4bit(torch.nn.Parameter):
...
@@ -356,6 +356,46 @@ class Params4bit(torch.nn.Parameter):
return
new_param
return
new_param
@
classmethod
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
if
kwargs
is
None
:
kwargs
=
{}
if
func
in
[
torch
.
chunk
,
torch
.
split
]:
tensor
=
args
[
0
]
result
=
super
().
__torch_function__
(
func
,
types
,
args
,
kwargs
)
if
isinstance
(
result
,
tuple
):
return
tuple
(
cls
(
data
=
chunk
,
requires_grad
=
tensor
.
requires_grad
,
quant_state
=
tensor
.
quant_state
,
blocksize
=
tensor
.
blocksize
,
compress_statistics
=
tensor
.
compress_statistics
,
quant_type
=
tensor
.
quant_type
,
quant_storage
=
tensor
.
quant_storage
,
module
=
tensor
.
module
,
bnb_quantized
=
tensor
.
bnb_quantized
,
)
for
chunk
in
result
)
else
:
return
cls
(
data
=
result
,
requires_grad
=
tensor
.
requires_grad
,
quant_state
=
tensor
.
quant_state
,
blocksize
=
tensor
.
blocksize
,
compress_statistics
=
tensor
.
compress_statistics
,
quant_type
=
tensor
.
quant_type
,
quant_storage
=
tensor
.
quant_storage
,
module
=
tensor
.
module
,
bnb_quantized
=
tensor
.
bnb_quantized
,
)
return
super
().
__torch_function__
(
func
,
types
,
args
,
kwargs
)
def
fix_4bit_weight_quant_state_from_module
(
module
:
Union
[
"Embedding4bit"
,
"Linear4bit"
]):
def
fix_4bit_weight_quant_state_from_module
(
module
:
Union
[
"Embedding4bit"
,
"Linear4bit"
]):
if
getattr
(
module
.
weight
,
"quant_state"
,
None
)
is
not
None
:
if
getattr
(
module
.
weight
,
"quant_state"
,
None
)
is
not
None
:
...
...
tests/test_linear4bit.py
View file @
42653921
...
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
...
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
assert
param
.
data
.
data_ptr
()
==
shallow_copy_param
.
data
.
data_ptr
()
assert
param
.
data
.
data_ptr
()
==
shallow_copy_param
.
data
.
data_ptr
()
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
())
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
,
"fp4"
])
def
test_params4bit_torch_chunk_split
(
device
,
quant_type
):
"""Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
if
device
==
"hpu"
and
not
is_supported_on_hpu
(
quant_type
,
torch
.
float16
,
torch
.
uint8
):
pytest
.
skip
(
"This configuration is not supported on HPU."
)
if
device
==
"cpu"
:
pytest
.
skip
(
"CPU quantization causes segfault, skipping CPU test"
)
original_tensor
=
torch
.
randn
(
8
,
4
,
dtype
=
torch
.
float16
,
device
=
"cpu"
)
params4bit
=
bnb
.
nn
.
Params4bit
(
data
=
original_tensor
,
quant_type
=
quant_type
,
requires_grad
=
False
)
if
device
!=
"cpu"
:
params4bit
=
params4bit
.
to
(
device
)
chunks
=
torch
.
chunk
(
params4bit
,
2
,
dim
=
0
)
assert
isinstance
(
chunks
,
tuple
),
"torch.chunk should return tuple"
for
chunk
in
chunks
:
assert
isinstance
(
chunk
,
bnb
.
nn
.
Params4bit
),
"Chunk should preserve Params4bit subclass"
assert
hasattr
(
chunk
,
"quant_type"
),
"Should preserve metadata"
assert
chunk
.
quant_type
==
params4bit
.
quant_type
,
"Should preserve quant_type value"
splits
=
torch
.
split
(
params4bit
,
2
,
dim
=
0
)
assert
isinstance
(
splits
,
tuple
),
"torch.split should return tuple"
assert
len
(
splits
)
>
0
,
"Should have at least one split"
for
split
in
splits
:
assert
isinstance
(
split
,
bnb
.
nn
.
Params4bit
),
"Split should preserve Params4bit subclass"
assert
hasattr
(
split
,
"quant_type"
),
"Should preserve metadata"
assert
split
.
quant_type
==
params4bit
.
quant_type
,
"Should preserve quant_type value"
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
())
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
())
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
,
"fp4"
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
"nf4"
,
"fp4"
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
64
,
128
]
if
not
HIP_ENVIRONMENT
else
[
128
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
64
,
128
]
if
not
HIP_ENVIRONMENT
else
[
128
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment