Unverified Commit 947db7cf authored by Tim Dettmers, committed by GitHub

Merge pull request #436 from akx/quanitze

Fix typo "quanitze"
parents 8c5c6689 6b26402b
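
This commit renames the misspelled Triton kernel int8_matmul_mixed_dequanitze to int8_matmul_mixed_dequantize and updates every import and call site accordingly. Since the module path changes along with the function name, any external code that imported the old spelling needs the same one-line fix, sketched here (taken directly from the import hunks below):

# old, misspelled import -- the module no longer exists under this name after the rename
# from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

# corrected import
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize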
@@ -8,7 +8,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
 # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
@@ -72,8 +72,8 @@ if __name__ == '__main__':
 get_time('standard_gx', lambda : g.matmul(w), info)
 get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
 get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
 get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
 get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
 get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
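
The first hunks touch a SwitchBack speed benchmark that compares the fp16 baseline (standard matmul) against the rowwise and "global" (mixed) int8 matmul paths, and also times the individual quantization kernels. The get_time helper itself is not shown in the diff; a minimal sketch of such a helper, assuming a CUDA-event based implementation and the (name, fn, info) call signature seen above (illustrative only, not the benchmark's actual code):

import torch

def get_time(name, fn, info, repeat=16):
    # Warm up so one-time compilation and caching do not skew the timing.
    for _ in range(4):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    info[name] = start.elapsed_time(end) / repeat  # mean milliseconds per call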
@@ -10,7 +10,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
 class _switchback_global(torch.autograd.Function):
@@ -29,7 +29,7 @@ class _switchback_global(torch.autograd.Function):
 # matmult, fused dequant and add bias
 # call "mixed" because we are mixing rowwise quantized and global quantized
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
 X_int8, W_int8.t(), state_X, state_W, bias
 ).view(*X_3D.size()[:-1], -1)
@@ -47,7 +47,7 @@ class _switchback_global(torch.autograd.Function):
 # so we transpose once then call .t() in the matmul
 G_int8, state_G = quantize_rowwise(G)
 W_int8, state_W = quantize_global_transpose(W)
-grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
 *G_3D.size()[:-1], -1
 )
 if ctx.needs_input_grad[1]:
@@ -119,7 +119,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
 # matmult, fused dequant and add bias
 # call "mixed" because we are mixing rowwise quantized and global quantized
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
 X_int8, W_int8.t(), state_X, state_W, bias
 ).view(*X_3D_sz[:-1], -1)
@@ -143,7 +143,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
 G_int8, state_G = quantize_rowwise(G)
 del G
 W_int8 = W_int8.t().contiguous()
-grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
 *G_3D_sz[:-1], -1
 )
@@ -215,7 +215,7 @@ class SwitchBackLinear(nn.Linear):
 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
 ).view(*x.size()[:-1], -1)
 else:
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
 ).view(*x.size()[:-1], -1)
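
These hunks are where the renamed kernel is actually used. In the SwitchBack autograd functions, activations and gradients are quantized rowwise (one scale per row), the weight is quantized with a single global scale, and int8_matmul_mixed_dequantize runs the int8 matmul with dequantization and the bias add fused in ("mixed" because it combines the rowwise and global scales). A simplified, self-contained sketch of the forward path for a 2D input, assuming quantize_global returns a (tensor, state) pair like quantize_rowwise does; shape handling, the memory-efficient variant, and the backward pass are omitted:

import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

def switchback_forward(X, W, bias=None):
    # X: (batch, in_features) fp16 activations; W: (out_features, in_features) fp16 weight.
    X_int8, state_X = quantize_rowwise(X)   # one scale per row of X
    W_int8, state_W = quantize_global(W)    # one scale for the whole weight tensor
    # int8 matmul with rescaling by both states and the bias add fused in.
    return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias)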
@@ -2,7 +2,7 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available
 if not is_triton_available():
-def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
 else:
 import triton
@@ -136,7 +136,7 @@ else:
 tl.atomic_add(C, acc, mask=mask)
-def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
 device = a.device
 divfactor = 1. / (127. * 127.)
 has_bias = 0 if bias is None else 1
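
The last file is the kernel module itself. Its guard pattern is visible in the first hunk: when Triton is unavailable, a stub with the (now correctly spelled) name is defined that simply returns None, so importing bitsandbytes never fails; the real Triton kernel is only defined in the else branch. A hypothetical caller-side check (not part of this diff) that turns the silent None into an explicit error:

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

def mixed_int8_matmul_or_fail(a_int8, b_int8, state_a, state_b, bias=None):
    # Without Triton the import above resolves to the stub returning None,
    # so fail loudly instead of propagating None downstream.
    if not is_triton_available():
        raise RuntimeError("int8_matmul_mixed_dequantize requires Triton")
    return int8_matmul_mixed_dequantize(a_int8, b_int8, state_a, state_b, bias)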