OpenDAS / bitsandbytes

Commit 2938c739, authored Aug 02, 2025 by ved1beta
Parent: 1dbe6021

    test_params4bit_torch_chunk_split

Showing 1 changed file with 35 additions and 0 deletions:

    tests/test_linear4bit.py  (+35, -0)
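For local reproduction, the added test can be selected on its own with pytest's keyword filter (assuming a checkout of the repository with its test dependencies installed):

    pytest tests/test_linear4bit.py -k test_params4bit_torch_chunk_split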
tests/test_linear4bit.py @ 2938c739

@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
 
 
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_params4bit_torch_chunk_split(device, quant_type):
+    """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
+    if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    if device == "cpu":
+        pytest.skip("CPU quantization causes segfault, skipping CPU test")
+
+    original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
+    params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
+
+    if device != "cpu":
+        params4bit = params4bit.to(device)
+
+    chunks = torch.chunk(params4bit, 2, dim=0)
+    assert isinstance(chunks, tuple), "torch.chunk should return tuple"
+    for chunk in chunks:
+        assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
+        assert hasattr(chunk, "quant_type"), "Should preserve metadata"
+        assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+    splits = torch.split(params4bit, 2, dim=0)
+    assert isinstance(splits, tuple), "torch.split should return tuple"
+    assert len(splits) > 0, "Should have at least one split"
+    for split in splits:
+        assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
+        assert hasattr(split, "quant_type"), "Should preserve metadata"
+        assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
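For context, the behaviour the test asserts comes from PyTorch's tensor-subclass dispatch: a subclass can hook __torch_function__ so that operations such as torch.chunk and torch.split return instances of the subclass with their metadata intact. The sketch below is a minimal, self-contained illustration of that mechanism, not the bitsandbytes implementation; TaggedTensor and its quant_type handling are hypothetical stand-ins for Params4bit.

import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for Params4bit: a Tensor subclass carrying quantization metadata."""

    def __new__(cls, data, quant_type="nf4"):
        # Wrap an existing tensor in the subclass and attach the metadata.
        self = data.as_subclass(cls)
        self.quant_type = quant_type
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # torch.Tensor's default implementation already re-wraps outputs in cls;
        # additionally copy the custom metadata onto chunk/split results.
        out = super().__torch_function__(func, types, args, kwargs or {})
        if func in (torch.chunk, torch.split, torch.Tensor.chunk, torch.Tensor.split):
            for piece in out:
                piece.quant_type = args[0].quant_type
        return out


t = TaggedTensor(torch.randn(8, 4, dtype=torch.float16), quant_type="nf4")
assert all(isinstance(c, TaggedTensor) and c.quant_type == "nf4" for c in torch.chunk(t, 2, dim=0))
assert all(isinstance(s, TaggedTensor) and s.quant_type == "nf4" for s in torch.split(t, 2, dim=0))

Per the test's docstring, FSDP2 relies on chunk/split when sharding parameters, which is presumably why the test ties subclass and quant_type preservation to FSDP2 compatibility.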