"docs/vscode:/vscode.git/clone" did not exist on "a0fc3db7df0849f05dfc1260da966bb7b6e24b52"
Commit 5b612bc6 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Added is_available_triton guard to Triton SwitchBackLinear.

parent c3d87e44
...@@ -3,6 +3,8 @@ import torch.nn as nn ...@@ -3,6 +3,8 @@ import torch.nn as nn
import time import time
from functools import partial from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
...@@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear): ...@@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear):
): ):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available:
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
# By default, we use the global quantization. # By default, we use the global quantization.
self.vectorize = vectorize self.vectorize = vectorize
if self.vectorize: if self.vectorize:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment