Commit e9f3605f authored by Matthew Douglas
Browse files

Fix Linear4bit warnings/test for compute dtype

parent 812ef06a
......@@ -455,14 +455,14 @@ class Linear4bit(nn.Linear):
self.compute_dtype = x.dtype
elif x.dtype == torch.float16:
# we take the compute dtype passed into the layer
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
)
warnings.filterwarnings("ignore", message=".*inference.")
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
)
......
......@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
dim1 = 64
with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)
with pytest.warns(UserWarning) as record:
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment