Unverified commit 9f858294, authored by Matthew Douglas, committed by GitHub

Add torch.compile tests (#1648)

* Add torch.compile tests

* Tests: WA aarch64 CPU regressions for torch 2.6.0; add Windows torch==2.7.0+cu118 test config

* Tests: skip torch.compile for cuda on windows
parent 503d243e
...
@@ -137,6 +137,10 @@ jobs:
         with:
           python-version: 3.9
+      - name: Setup MSVC
+        if: startsWith(matrix.os, 'windows')
+        uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
       - name: Install dependencies
         run: |
           pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
...
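Why the MSVC step: on Windows CPU jobs, torch.compile's Inductor backend lowers models to generated C++ that it builds with a host compiler (cl.exe on Windows), so the workflow puts the MSVC toolchain on PATH via ilammy/msvc-dev-cmd. A minimal sketch of that dependency, assuming only stock PyTorch (the function `f` is illustrative):

```python
import torch

def f(x):
    return torch.nn.functional.gelu(x) * 2.0

# On CPU, TorchInductor emits C++ for this function and invokes a host
# compiler (cl.exe on Windows, g++/clang++ elsewhere). Without a compiler
# on PATH, this call fails at compile time.
compiled_f = torch.compile(f)
print(compiled_f(torch.randn(8)))
```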
...
@@ -201,18 +205,40 @@ jobs:
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
-          # L40S runners
+          # Linux L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
-          # T4 runners
+          # Linux T4 runners
           - os: ubuntu-22.04
             gpu: T4
             runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
+          # Specific Windows runners using cu118
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.2.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
           - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.6.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
             gpu: T4
             runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
...
...
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)

     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
...
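This change makes each QuantState own a private copy of the quantization code table. `Tensor.to` returns the tensor itself when no device or dtype conversion is needed, so the previous `code.to(A.device)` could alias the shared module-level codebook; `copy=True` always materializes a fresh tensor. A minimal sketch of that semantic difference (the `code` tensor here is a stand-in, not the bitsandbytes codebook):

```python
import torch

code = torch.linspace(-1.0, 1.0, 256)  # stand-in for the shared codebook

aliased = code.to(code.device)             # no conversion needed: returns code itself
private = code.to(code.device, copy=True)  # always allocates a new tensor

assert aliased is code
assert private is not code and torch.equal(private, code)
```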
...
@@ -493,7 +493,7 @@ class Linear4bit(nn.Linear):
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)

-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


 class LinearFP4(Linear4bit):
...
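Dropping `.data` in Linear4bit.forward keeps the transpose visible to autograd and to dynamo: `Parameter.data` returns a detached alias, which autograd does not track and which torch.compile has to treat conservatively, so avoiding it is presumably what makes this path compile cleanly. A small illustration of the detaching behavior, assuming nothing beyond stock PyTorch:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))

# .data is a detached alias: gradients do not flow through it.
assert w.data.t().requires_grad is False
# Using the parameter directly keeps the transpose on the autograd tape.
assert w.t().requires_grad is True
```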
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
 import torch

 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)

 storage = {
     "uint8": torch.uint8,
...
@@ -275,3 +283,85 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
if device == "cpu" and quant_type == "fp4":
pytest.skip("FP4 is not supported for CPU")
if fullgraph and torch.__version__ < (2, 8):
pytest.skip("fullgraph mode requires torch 2.8 or higher")
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
if (
not fullgraph
and device == "cpu"
and platform.machine() == "aarch64"
and platform.system() == "Linux"
and ((2, 7) > torch.__version__ >= (2, 6))
):
pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
dim = 256
batch_size = 16
torch.compiler.reset()
# Create a small network with Linear4bit layers
net = torch.nn.Sequential(
*[
bnb.nn.Linear4bit(
dim,
dim,
bias=bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
for _ in range(4)
]
).to(device)
# Create input tensor
x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
# Get reference output before compilation
with torch.no_grad():
ref_output = net(x)
# Compile the model
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
# Get output from compiled model
with torch.no_grad():
compiled_output = compiled_net(x)
# Check outputs match
assert compiled_output.shape == ref_output.shape
assert compiled_output.device == ref_output.device
assert compiled_output.dtype == ref_output.dtype
torch.testing.assert_close(compiled_output, ref_output)
# Test with gradients
x.requires_grad_(True)
y1 = net(x).sum()
y1.backward()
grad_ref = x.grad.clone()
x.grad = None
y2 = compiled_net(x).sum()
y2.backward()
grad_compiled = x.grad.clone()
torch.testing.assert_close(grad_compiled, grad_ref)
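The test above follows a compile-parity pattern: reset compiler caches, compare eager vs. compiled forward outputs, then compare input gradients. A hypothetical helper distilling that pattern (the name `assert_compile_parity` is illustrative and not part of the PR):

```python
import torch

def assert_compile_parity(net, x, fullgraph=False, mode="default"):
    # Drop caches so state from earlier parametrized cases doesn't leak in.
    torch.compiler.reset()
    compiled = torch.compile(net, fullgraph=fullgraph, mode=mode)

    # Forward parity.
    with torch.no_grad():
        torch.testing.assert_close(compiled(x), net(x))

    # Input-gradient parity.
    x.requires_grad_(True)
    net(x).sum().backward()
    grad_ref, x.grad = x.grad.clone(), None
    compiled(x).sum().backward()
    torch.testing.assert_close(x.grad, grad_ref)
```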
...
@@ -2,6 +2,7 @@ from contextlib import nullcontext
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
...
@@ -224,3 +225,68 @@ def test_linear8bit_serialization(linear8bit):
     # check for a bug where SCB and CB were not copied
     assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
dim = 256
batch_size = 16
torch.compiler.reset()
# Create a small network with Linear8bitLt layers
net = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
).to(device)
dynamic_output_shapes = fullgraph and threshold > 0
with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
# Create input tensor
x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)
# Get reference output before compilation
with torch.no_grad():
ref_output = net(x)
# Compile the model
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
# Get output from compiled model
with torch.no_grad():
compiled_output = compiled_net(x)
# Check outputs match
assert compiled_output.shape == ref_output.shape
assert compiled_output.device == ref_output.device
assert compiled_output.dtype == ref_output.dtype
torch.testing.assert_close(compiled_output, ref_output)
# Test with gradients. Currently only works with threshold=0.
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
is_broken_platform = (
device == "cpu"
and platform.machine() == "aarch64"
and platform.system() == "Linux"
and ((2, 7) > torch.__version__ >= (2, 6))
)
if threshold == 0 and not is_broken_platform:
x.requires_grad_(True)
y1 = net(x).sum()
y1.backward()
grad_ref = x.grad.clone()
x.grad = None
y2 = compiled_net(x).sum()
y2.backward()
grad_compiled = x.grad.clone()
torch.testing.assert_close(grad_compiled, grad_ref)
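The `capture_dynamic_output_shape_ops` patch matters because LLM.int8() with threshold > 0 selects outlier columns at runtime: ops like `nonzero` produce data-dependent output shapes, which dynamo graph-breaks on by default, and `fullgraph=True` turns that break into a hard error. A minimal sketch of the mechanism, assuming stock PyTorch (`outlier_columns` is illustrative, not the bitsandbytes kernel):

```python
import torch

def outlier_columns(x, threshold=6.0):
    # nonzero() has a data-dependent output shape; with fullgraph=True,
    # dynamo rejects it unless capture_dynamic_output_shape_ops is enabled.
    return (x.abs().amax(dim=0) > threshold).nonzero().flatten()

with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True):
    compiled = torch.compile(outlier_columns, fullgraph=True)
    print(compiled(torch.randn(16, 64) * 4.0))
```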