Unverified Commit 97073cdb authored by Matthew Douglas, committed by GitHub

Support LLM.int8() inference with torch.compile (#1594)

parent feaedbb0
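
For context: this change lets the LLM.int8() inference path run under torch.compile by exposing the mixed-precision decomposition as a single custom op with a fake (meta) registration. A minimal sketch of the usage this targets — the layer size, batch shape, and threshold value below are illustrative, not taken from this commit:

import torch
from bitsandbytes.nn import Linear8bitLt

# threshold=6.0 is the usual LLM.int8() outlier threshold
linear = Linear8bitLt(4096, 4096, has_fp16_weights=False, threshold=6.0)
linear = linear.half().cuda()  # weights are quantized to int8 on the move to the GPU

compiled = torch.compile(linear)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    y = compiled(x)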
@@ -15,6 +15,32 @@ else:
     register_fake = torch.library.impl_abstract
     register_kernel = torch.library.impl
 
+
+# Int8 mixed precision matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_mixed_scaled_mm",
+    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_mixed_scaled_mm")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    shapeC = (*CA.shape[:-1], CB.shape[0])
+    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
+    subA = A.new_empty(outlier_cols, dtype=torch.int64)
+    return out, subA
+
+
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
......
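
The fake (meta) implementation above is what torch.compile traces instead of the real kernel; torch.library.get_ctx().new_dynamic_size() allocates an unbacked symbolic size so the number of outlier columns can stay data-dependent without forcing a graph break. A toy illustration of the same pattern, separate from this diff — the demo::nonzero_rows op is hypothetical:

import torch

torch.library.define("demo::nonzero_rows", "(Tensor x) -> Tensor")

@torch.library.register_fake("demo::nonzero_rows")
def _(x):
    # Output length is unknown at trace time: allocate an unbacked symbolic size.
    n = torch.library.get_ctx().new_dynamic_size()
    return x.new_empty(n, dtype=torch.int64)

@torch.library.impl("demo::nonzero_rows", "cpu")
def _(x):
    return torch.argwhere(x.any(dim=1)).view(-1)

rows = torch.ops.demo.nonzero_rows(torch.eye(3))  # eager dispatch uses the cpu kernel;
# under torch.compile, tracing consults the fake registration instead.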
@@ -210,37 +210,28 @@ class MatMul8bitLt(torch.autograd.Function):
             # 2. Quantize B
             state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-        # Handle sparse decomposition. In some instances, we may have not found any
-        # outlier columns at all. In that case, we'll skip this part completely.
-        if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
+        # Handle sparse decomposition
+        if state.threshold > 0.0:
             state.idx = outlier_cols
 
-            # Zero out the outliers in the transposed 8bit inputs.
-            if CAt is not None:
-                CAt[:, state.idx] = 0
-
-            # Extract the input outliers in original precision
-            subA = A[:, state.idx].contiguous()
-
-            # Extract the corresponding weights
-            if state.has_fp16_weights:
-                state.subB = B[:, state.idx].t()
-            else:
-                # To dequantize our weights associated with the input outliers,
-                # we want to divide by 127. It's however more performant to multiply
-                # by the reciprocal.
-                outliers = state.CB[:, state.idx]
-                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
+            # Mixed Int8 Matmul + Dequant + Bias
+            output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
+                A,
+                CA,
+                state.CB,
+                SCA,
+                state.SCB,
+                outlier_cols,
+                bias,
+            )
         else:
+            # Int8 Matmul + Dequant + Bias
+            output = torch.ops.bitsandbytes.int8_scaled_mm.default(
+                CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
+            )
             subA = None
 
-        # 3. Int8 Matmul + Dequant + Bias
-        output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
-
-        # 4. Mixed-precision decomposition matmul
-        if subA is not None and state.subB is not None:
-            output = output.addmm(subA, state.subB)
-
         # 5. Save state
         ctx.state = state
......
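
The hunk above collapses the eager-mode sequence (zero outliers in CAt, slice subA/subB, int8 matmul, then addmm) into one op call, so Dynamo sees a single node. For clarity, a pure-PyTorch reference of what the fused op computes — a sketch, not the library's kernel, under the assumption that the quantization step already zeroed outlier columns in CA when threshold > 0, so their fp16 contribution is not double-counted:

import torch

def int8_mixed_scaled_mm_ref(A, CA, CB, SCA, SCB, outlier_cols=None, bias=None):
    # Dequantized int8 matmul: A_ij ~= CA_ij * SCA_i / 127, B_kj ~= CB_kj * SCB_k / 127
    out = ((CA.float() @ CB.float().t()) * SCA.view(-1, 1) * SCB.view(1, -1) / (127.0 * 127.0)).to(A.dtype)
    subA = None
    if outlier_cols is not None and outlier_cols.numel():
        # Outlier columns carried no int8 contribution (quantized to zero),
        # so their fp16 contribution is added back at full precision.
        subA = A[:, outlier_cols].contiguous()
        subB = (CB[:, outlier_cols].float().t() * SCB.view(1, -1) / 127.0).to(A.dtype)
        out = out.addmm(subA, subB)
    if bias is not None:
        out = out + bias
    return out, subA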
@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A
......
@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):
         if outliers.any():
             outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+        else:
+            # Needed for torch.compile support.
+            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
 
     with _cuda_device_of(A):
         lib.cint8_vector_quant(
......
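
The small else branch in the last hunk matters more than it looks: by returning a zero-numel tensor instead of None, the op's output type stays stable across calls, so torch.compile does not have to specialize on "outliers found" vs "no outliers" at the Optional level, and numel() becomes a shape guard rather than a data-dependent branch. A standalone sketch of the idea (toy function, not library code):

import torch

def mask_cols(x, outlier_cols):
    out = x.clone()
    # numel() is a size check Dynamo can guard on; because outlier_cols is
    # always a Tensor (possibly empty), the traced graph's signature is stable.
    if outlier_cols.numel():
        out[:, outlier_cols] = 0
    return out

x = torch.randn(4, 8)
no_outliers = torch.empty(0, dtype=torch.int64)  # "none found", still a Tensor
mask_cols(x, no_outliers)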