Unverified Commit 947db7cf authored by Tim Dettmers's avatar Tim Dettmers Committed by GitHub
Browse files

Merge pull request #436 from akx/quanitze

Fix typo "quanitze"
parents 8c5c6689 6b26402b
......@@ -8,7 +8,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
......@@ -72,8 +72,8 @@ if __name__ == '__main__':
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
......
......@@ -10,7 +10,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
class _switchback_global(torch.autograd.Function):
......@@ -29,7 +29,7 @@ class _switchback_global(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
......@@ -47,7 +47,7 @@ class _switchback_global(torch.autograd.Function):
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
......@@ -119,7 +119,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
......@@ -143,7 +143,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
......@@ -215,7 +215,7 @@ class SwitchBackLinear(nn.Linear):
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
......
......@@ -2,7 +2,7 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:
import triton
......@@ -136,7 +136,7 @@ else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment