Commit 5b612bc6 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Added is_available_triton guard to Triton SwitchBackLinear.

parent c3d87e44
......@@ -3,6 +3,8 @@ import torch.nn as nn
import time
from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
......@@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear):
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available:
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
# By default, we use the global quantization.
self.vectorize = vectorize
if self.vectorize:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment