Unverified commit 9f858294 authored by Matthew Douglas, committed by GitHub

Add torch.compile tests (#1648)

* Add torch.compile tests

* Tests: work around aarch64 CPU regressions for torch 2.6.0; add Windows torch==2.7.0+cu118 test config

* Tests: skip torch.compile for cuda on windows
parent 503d243e
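
To make the intent of the change concrete, here is a minimal sketch of the behavior the new tests lock down (illustrative shapes matching the tests below; assumes a CUDA device and torch >= 2.4):

```python
import torch
import bitsandbytes as bnb

# Build one 4-bit layer; quantization happens on the .to("cuda") move.
layer = bnb.nn.Linear4bit(256, 256, compute_dtype=torch.float32, quant_type="nf4").to("cuda")
x = torch.randn(16, 256, dtype=torch.float32, device="cuda")

with torch.no_grad():
    ref = layer(x)                 # eager reference
    out = torch.compile(layer)(x)  # compiled path should match it

torch.testing.assert_close(out, ref)
```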
@@ -137,6 +137,10 @@ jobs:
       with:
         python-version: 3.9
+    - name: Setup MSVC
+      if: startsWith(matrix.os, 'windows')
+      uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
     - name: Install dependencies
       run: |
         pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
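
As a hedged aside on the MSVC step: on Windows CPU runs, inductor compiles its generated C++ with a host compiler, which is why `cl` must be on PATH (that is what the `msvc-dev-cmd` action arranges). A quick illustrative preflight check, not part of this PR:

```python
import platform
import shutil

# Sketch: torch.compile's inductor CPU backend shells out to a C++ compiler;
# on Windows that is cl.exe, made available by an MSVC developer environment.
if platform.system() == "Windows" and shutil.which("cl") is None:
    raise RuntimeError("cl.exe not found on PATH; torch.compile C++ codegen would fail")
```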
@@ -201,18 +205,40 @@ jobs:
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
-          # L40S runners
+          # Linux L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
-          # T4 runners
+          # Linux T4 runners
           - os: ubuntu-22.04
             gpu: T4
             runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
+          # Specific Windows runners using cu118
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.2.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.6.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
......
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)

     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
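
A note on `copy=True` (my reading; the diff itself does not state the motivation): `Tensor.to(device)` returns the original tensor unchanged when it is already on the target device, so without the copy a `QuantState` can end up aliasing the shared module-level `code` buffer. Forcing a copy gives each state its own tensor, which is safer under torch.compile. A minimal sketch of that aliasing behavior:

```python
import torch

code = torch.linspace(-1.0, 1.0, 256)  # stand-in for the shared code buffer

same = code.to(code.device)              # no-op move: returns the same object
fresh = code.to(code.device, copy=True)  # always materializes a new tensor

assert same is code
assert fresh is not code and torch.equal(fresh, code)
```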
......
@@ -493,7 +493,7 @@ class Linear4bit(nn.Linear):
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)

-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


 class LinearFP4(Linear4bit):
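
A hedged note on dropping `.data`: `Parameter.data` returns a detached alias that autograd (and, plausibly, dynamo tracing) cannot track, so going through it inside `forward` is hostile to compilation; calling `.t()` on the parameter itself keeps the weight visible. A small sketch of the detachment:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))

# .data is a detached alias: no grad tracking, invisible to autograd.
assert w.data.requires_grad is False
# A view taken on the parameter itself stays in the autograd graph.
assert w.t().requires_grad is True
```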
......
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
 import torch

 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)

 storage = {
     "uint8": torch.uint8,
@@ -275,3 +283,85 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
if device == "cpu" and quant_type == "fp4":
pytest.skip("FP4 is not supported for CPU")
if fullgraph and torch.__version__ < (2, 8):
pytest.skip("fullgraph mode requires torch 2.8 or higher")
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
if (
not fullgraph
and device == "cpu"
and platform.machine() == "aarch64"
and platform.system() == "Linux"
and ((2, 7) > torch.__version__ >= (2, 6))
):
pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
dim = 256
batch_size = 16
torch.compiler.reset()
# Create a small network with Linear4bit layers
net = torch.nn.Sequential(
*[
bnb.nn.Linear4bit(
dim,
dim,
bias=bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
for _ in range(4)
]
).to(device)
# Create input tensor
x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
# Get reference output before compilation
with torch.no_grad():
ref_output = net(x)
# Compile the model
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
# Get output from compiled model
with torch.no_grad():
compiled_output = compiled_net(x)
# Check outputs match
assert compiled_output.shape == ref_output.shape
assert compiled_output.device == ref_output.device
assert compiled_output.dtype == ref_output.dtype
torch.testing.assert_close(compiled_output, ref_output)
# Test with gradients
x.requires_grad_(True)
y1 = net(x).sum()
y1.backward()
grad_ref = x.grad.clone()
x.grad = None
y2 = compiled_net(x).sum()
y2.backward()
grad_compiled = x.grad.clone()
torch.testing.assert_close(grad_compiled, grad_ref)
@@ -2,6 +2,7 @@ from contextlib import nullcontext
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
@@ -224,3 +225,68 @@ def test_linear8bit_serialization(linear8bit):
     # check for a bug where SCB and CB were not copied
     assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
dim = 256
batch_size = 16
torch.compiler.reset()
# Create a small network with Linear8bitLt layers
net = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
).to(device)
dynamic_output_shapes = fullgraph and threshold > 0
with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
# Create input tensor
x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)
# Get reference output before compilation
with torch.no_grad():
ref_output = net(x)
# Compile the model
compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
# Get output from compiled model
with torch.no_grad():
compiled_output = compiled_net(x)
# Check outputs match
assert compiled_output.shape == ref_output.shape
assert compiled_output.device == ref_output.device
assert compiled_output.dtype == ref_output.dtype
torch.testing.assert_close(compiled_output, ref_output)
# Test with gradients. Currently only works with threshold=0.
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
is_broken_platform = (
device == "cpu"
and platform.machine() == "aarch64"
and platform.system() == "Linux"
and ((2, 7) > torch.__version__ >= (2, 6))
)
if threshold == 0 and not is_broken_platform:
x.requires_grad_(True)
y1 = net(x).sum()
y1.backward()
grad_ref = x.grad.clone()
x.grad = None
y2 = compiled_net(x).sum()
y2.backward()
grad_compiled = x.grad.clone()
torch.testing.assert_close(grad_compiled, grad_ref)
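
A hedged aside on the `capture_dynamic_output_shape_ops` patch above: with `threshold > 0`, Linear8bitLt extracts outlier features, which involves ops whose output shapes depend on the data (in the spirit of `torch.nonzero`), and `fullgraph=True` cannot trace those unless the flag is set. An isolated sketch with an illustrative threshold:

```python
import torch

@torch.compile(fullgraph=True)
def outlier_indices(x):
    # nonzero's output shape depends on the data, so dynamo needs
    # capture_dynamic_output_shape_ops=True to trace it without a graph break.
    return torch.nonzero(x.abs() > 6.0)

with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True):
    idx = outlier_indices(torch.randn(16, 256))
```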