Commit 758c7175 authored by Tim Dettmers

Merge branch 'debug' into cuda-bin-switch-and-cli

parents 96bc209b ab72a129
@@ -58,7 +58,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
...
from dataclasses import dataclass
import torch
import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -199,6 +199,17 @@ class MatmulLtState:

class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if math.prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):

                ctx.grad_shape
            )

        return grad_A, grad_B, None, None


matmul = MatMul8bitLt.apply
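
An illustrative sketch (not part of the diff) of the empty-input path added above, assuming bnb.matmul dispatches to MatMul8bitLt as in this module:

import torch
import bitsandbytes as bnb

# A zero-sized activation now yields an empty fp16 result instead of an error,
# and backward() returns all-zero gradients via the ctx.is_empty branch above.
A = torch.empty((0, 32), dtype=torch.float16, device='cuda', requires_grad=True)
B = torch.randn((32, 64), dtype=torch.float16, device='cuda', requires_grad=True)
out = bnb.matmul(A, B)      # shape (0, 64), dtype torch.float16
out.sum().backward()        # A.grad and B.grad are zero tensors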
...
@@ -4,9 +4,10 @@

# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
import math
import torch

from typing import Tuple
from torch import Tensor

from .cextension import COMPILED_WITH_CUDA, lib
@@ -193,6 +194,14 @@ def get_special_format_str():

    return "col_turing"


def is_on_gpu(tensors):
    on_gpu = True
    for t in tensors:
        if t is None: continue # NULL pointers are fine
        on_gpu &= t.device.type == 'cuda'
    return on_gpu
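
For context, a small usage sketch (not from the commit) of the is_on_gpu guard that the native calls below now run; the tensors here are made up:

import torch
import bitsandbytes.functional as F

a = torch.randn(4, device='cuda')
b = torch.randn(4)                     # CPU tensor
F.is_on_gpu([a, None, a])              # True: None entries (NULL pointers) are skipped
F.is_on_gpu([a, b])                    # False: b lives on the CPU, not on a CUDA device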

def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer from a PyTorch Tensor.
@@ -336,7 +345,7 @@ def nvidia_transform(

def estimate_quantiles(
    A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -361,9 +370,9 @@ def estimate_quantiles(
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.
    '''
    if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    is_on_gpu([A, out])
    if A.dtype == torch.float32:
        lib.cestimate_quantiles_fp32(
            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
@@ -428,7 +437,8 @@ def quantize_blockwise(
    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != 'cpu':
        is_on_gpu([code, A, absmax, out, rand])
        if rand is not None:
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
@@ -541,7 +551,8 @@ def dequantize_blockwise(
            f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
        )

    if A.device.type != 'cpu':
        is_on_gpu([A, out])
        if out.dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(
                get_ptr(quant_state[1]),
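
For orientation, a rough round-trip through the two blockwise functions patched above (illustrative only; call signatures are assumed from this module):

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device='cuda')
C, quant_state = F.quantize_blockwise(A)       # uint8 codes plus (absmax, code) state
A2 = F.dequantize_blockwise(C, quant_state)    # float32 reconstruction of A
max_err = (A - A2).abs().max()                 # small, bounded by the per-block absmax scaling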
@@ -610,7 +621,7 @@ def dequantize(

def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
    is_on_gpu([A, out])
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out

def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Dequantizes the 8-bit tensor to 32-bit.

    Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    -------
    torch.Tensor:
        32-bit output tensor.
    '''
    if out is None: out = torch.zeros_like(A, dtype=torch.float32)
    is_on_gpu([code, A, out])
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out
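
A minimal round-trip sketch for the two helpers above (not part of the diff); it assumes create_dynamic_map() from this module supplies the 256-entry codebook and that the inputs lie inside the code's [-1, 1] range:

import torch
import bitsandbytes.functional as F

code = F.create_dynamic_map().to('cuda')
A = torch.rand(1024, device='cuda') * 2 - 1     # values inside the codebook range
q = F.quantize_no_absmax(A, code)               # uint8 indices into `code`
A2 = F.dequantize_no_absmax(q, code)            # float32 approximation of A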
@@ -983,6 +992,7 @@ def percentile_clipping(
        The current optimization steps (number of past gradient norms).
    """
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
            get_ptr(grad),
@@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())
    is_on_gpu([histogram, index1, index2, source])
    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)


def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    if not torch.cuda.is_initialized(): torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
@@ -1212,21 +1212,10 @@ def igemm(
    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out
@@ -1306,24 +1295,10 @@ def batched_igemm(
    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
                       get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
                       ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out
@@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    rows = n = shapeB[0]
    assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
    elif shapeA[1] == 0 and dimsA == 3:
        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
@@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    has_error = 0
    ptrRowScale = get_ptr(None)
    is_on_gpu([A, B, out])
    if formatB == 'col_turing':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
@@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
            )

    if has_error == 1:
        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
        raise Exception('cublasLt ran into an error!')

    torch.cuda.set_device(prev_device)
@@ -1457,16 +1439,8 @@ def mm_dequant(
    numRows = ct.c_int32(out_shape[0])
    numCols = ct.c_int32(out_shape[1])

    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)

    return out
@@ -1507,15 +1481,8 @@ def get_colrow_absmax(
    cols = ct.c_int32(cols)

    prev_device = pre_call(A.device)
    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
    post_call(prev_device)

    if threshold > 0.0:
@@ -1642,6 +1609,7 @@ def double_quant(
    ptrOutCol = get_ptr(out_col)
    ptrOutRow = get_ptr(out_row)

    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
@@ -1714,33 +1682,19 @@ def get_special_format_str():
        )
    assert major >= 7

    if major == 7: return 'col_turing'
    elif major == 8: return 'col_ampere'
    else: return 'col_turing'


def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    prev_device = pre_call(A.device)
    if state is None: state = (A.shape, from_order)
    else: from_order = state[1]
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
    else: new_state = (state[0], to_order) # (shape, order)

    shape = state[0]
    if len(shape) == 2:
@@ -1752,7 +1706,8 @@ def transform(
    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    is_on_gpu([A, out])
    if to_order == 'col32':
        if transpose:
            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
@@ -1773,9 +1728,9 @@ def transform(
        elif from_order == "col_ampere":
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
    else:
        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')

    post_call(prev_device)

    return out, new_state
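
A rough usage sketch (not from the commit) showing how transform() feeds igemmlt(); the shapes and the col32/col_turing layout choice follow the conventions visible in this file:

import torch
import bitsandbytes.functional as F

A8 = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device='cuda')   # [m, k]
B8 = torch.randint(-128, 127, (96, 64), dtype=torch.int8, device='cuda')   # [n, k]
C32A, SA = F.transform(A8, to_order='col32')
CxB, SB = F.transform(B8, to_order=F.get_special_format_str())
out32, Sout = F.igemmlt(C32A, CxB, SA, SB)      # int32 accumulators in col32 layout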
@@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))

    return out
@@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    # print(cooA.rowidx[:64])
    # print(cooA.colidx[:64].sort()[0])

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(
            ptrMaxCount,
@@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
    ptrIdx = get_ptr(idx)
    ptrOut = get_ptr(out)

    prev_device = pre_call(A.device)
    if formatA == 'col_turing':
        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    elif formatA == "col_ampere":
        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    post_call(prev_device)

    return out
@@ -19,53 +19,59 @@ using std::endl;

void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
  int threads = 512;
  int num_blocks = n/threads;
  num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
  int num_blocks = n/4096;
  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
  kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void quantize(float *code, float *A, unsigned char *out, int n)
{
  int num_blocks = n/1024;
  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void dequantize(float *code, unsigned char *A, float *out, int n)
{
  int num_blocks = n/1024;
  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
{
  int num_blocks = n/4096;
  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
  int num_blocks = n/blocksize;
  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  if(blocksize == 4096)
    kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
  else if(blocksize == 2048)
    kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                const float beta1, const float beta2, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
  int num_blocks = n/4096;
  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  switch(OPTIMIZER)
  {
    case ADAM:
      if(max_unorm > 0.0f)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
      }

      kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    case MOMENTUM:
@@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
      if(max_unorm > 0.0f)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
      }

      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
  }
@@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
                float weight_decay,
                const float gnorm_scale, int n)
{
  int num_blocks = n/4096;
  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

  if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
@@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
    case ADAM:
      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
      CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                                quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
@@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
    case RMSPROP:
    case ADAGRAD:
      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
                                                quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
@@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
  int num_blocks = 0;
  switch(OPTIMIZER)
  {
    case ADAM:
      num_blocks = n/BLOCKSIZE_2STATE;
      num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
                                                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
      num_blocks = n/BLOCKSIZE_1STATE;
      num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
                                                quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
      break;
@@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g

template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
  int num_blocks = n/2048;
  num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
  kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
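
The same launch-geometry guard is added to every launcher in this file; a Python mirror of the arithmetic (illustration only, not part of the source) makes the ceil-division and the 65,535-block cap explicit:

def grid_blocks(n: int, elems_per_block: int) -> int:
    # Ceil-divide the element count by the per-block workload, as the C++ launchers do,
    # and refuse launches that exceed the 65,535-block cap these kernels assume.
    num_blocks = n // elems_per_block
    if n % elems_per_block != 0:
        num_blocks += 1
    assert num_blocks <= 65535, "CUDA ERROR: Maximum number of blocks for kernel exceeded"
    return num_blocks

grid_blocks(10_000, 2048)   # 5 blocks, matching percentileClipping above for n = 10_000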
@@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
  int num_blocks = numRows/subtile_rows;
  num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
  num_blocks = num_blocks*(tileCols/32);
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
  assert(threads <= tilesize);

  kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
  int tile_cols = STATS_THREADS*STATS_ITEMS;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
  int row_tiles = (tiledRows/STATS_ROWS);
  int col_tiles = (tiledCols/tile_cols);
  row_tiles = row_tiles > 0 ? row_tiles : 1;
  col_tiles = col_tiles > 0 ? col_tiles : 1;
  int num_blocks = row_tiles * col_tiles;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

  if(nnz_threshold == 0.0)
    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
@@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
  int tile_rows = 16;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int row_tiles = (tiledRows/tile_rows);
  int col_tiles = (tiledCols/tile_cols);
  row_tiles = row_tiles > 0 ? row_tiles : 1;
  col_tiles = col_tiles > 0 ? col_tiles : 1;
  int num_blocks = row_tiles * col_tiles;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

  if(threshold > 0.0f)
    kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
  else
@@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
  int tile_rows = 32;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int row_tiles = (tiledRows/tile_rows);
  int col_tiles = (tiledCols/tile_cols);
  row_tiles = row_tiles > 0 ? row_tiles : 1;
  col_tiles = col_tiles > 0 ? col_tiles : 1;
  int num_blocks = row_tiles * col_tiles;
  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

  int outCols = fill_up_to_nearest_multiple(cols, 32);
  int outRows = fill_up_to_nearest_multiple(rows, 32);
  if(FORMAT == COL_TURING)
@@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
    }
  }

  kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
...
@@ -40,7 +40,8 @@ names = [
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
@@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
@@ -385,9 +383,14 @@ def test_matmullt(
            )
        if req_grad[1]:
            n = gradB1.numel()
            if dim2 > 0:
                assert torch.abs(gradB1).sum() > 0.0
                assert torch.abs(gradB2).sum() > 0.0
            else:
                assert torch.abs(gradB1).sum() == 0.0
                assert torch.abs(gradB2).sum() == 0.0
            idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
            assert (idx == 0).sum().item() < n * 0.1
            idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
            assert (idx == 0).sum().item() < n * 0.02
...