Commit aa26c99e authored by yan.yan
Browse files

working on quantization

parent ee8c9465
......@@ -207,7 +207,7 @@ class Net(nn.Module):
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0",
spconv.SubMConv3d(16, 64, 3, bias=False, indice_key="c0",
algo=algo),
nn.BatchNorm1d(64),
nn.ReLU(),
......@@ -373,6 +373,11 @@ class Net(nn.Module):
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, voxel_num=vx_num)
return self.net(x)
def _set_enable_int8_test_inplace(simple_module: torch.fx.GraphModule, enable: bool):
    """Set ``enable_int8_test_mode`` on eligible sparse convolutions, in place.

    Walks every submodule of ``simple_module`` and assigns ``enable`` to the
    ``enable_int8_test_mode`` attribute of each ``SparseConvolution`` whose
    input and output channel counts are both multiples of 32. All other
    modules are left untouched.

    Args:
        simple_module: module whose submodules are scanned.
        enable: value written to ``enable_int8_test_mode``.
    """
    for submodule in simple_module.modules():
        if not isinstance(submodule, SparseConvolution):
            continue
        # Only layers with both channel counts divisible by 32 are switched.
        channels_aligned = (submodule.in_channels % 32 == 0
                            and submodule.out_channels % 32 == 0)
        if channels_aligned:
            submodule.enable_int8_test_mode = enable
class MyTracer(torch.fx.Tracer):
......@@ -387,6 +392,7 @@ def main():
torch.backends.cudnn.allow_tf32 = False
with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
voxels = np.random.uniform(-1, 1, size=[voxels.shape[0], 16]).astype(np.float32)
np.random.seed(50051)
device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0")
......@@ -408,6 +414,10 @@ def main():
out_fused = net_fused(voxels_th_cuda, coors_th_cuda, 1)
res = Fsp.sparse_add_hash_based(out_ref, out_fused.minus())
print(torch.linalg.norm(res.features))
_set_enable_int8_test_inplace(net_fused, True)
qvoxels_cuda = voxels_th_cuda.to(torch.int8)
out_int8 = net_fused(qvoxels_cuda, coors_th_cuda, 1)
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
{SPCONV_ALLOC_OUT_FEATURES, out_features}};
StaticAllocator alloc2(tensor_dict);
ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp());
auto conv_res = ConvGemmOps::implicit_gemm(
auto conv_run_status = ConvGemmOps::implicit_gemm(
alloc2, tuner, input_features_real, weights, pair_fwd_real,
pair_mask_splits, mask_argsort_splits, num_act_out_real,
mask_tensor, arch, false, is_subm,
......@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
1.0 /*bias alpha, only used for leaky relu*/,
0.0 /*unused for now*/, tv::gemm::Activation::kReLU);
tv::ssprint("selected conv algo",
std::get<1>(conv_res).algo_desp.__repr__());
std::get<1>(conv_run_status).algo_desp.__repr__());
// FINISH!!!
}
// calc maximum number of output points.
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import torch
import spconv.pytorch as spconv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
@contextlib.contextmanager
def identity_ctx():
    """No-op context manager: does nothing on entry or exit.

    Yields ``None`` exactly once, so ``with identity_ctx() as v`` binds
    ``v`` to ``None`` and simply runs the body.
    """
    yield None
class Net(nn.Module):
    """MNIST classifier with a sparse-convolution trunk and a dense head.

    The trunk (``self.net``) is a ``spconv.SparseSequential`` ending in
    ``spconv.ToDense()``; its dense output is flattened and passed through two
    fully connected layers, producing log-probabilities over 10 classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.net = spconv.SparseSequential(
            nn.BatchNorm1d(1),
            spconv.SubMConv2d(1, 32, 3, 1),
            nn.ReLU(),
            spconv.SubMConv2d(32, 64, 3, 1),
            nn.ReLU(),
            spconv.SparseConv2d(64, 64, 2, 2),
            spconv.ToDense(),
        )
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        # Both dropouts are applied to flattened 2-D activations in forward().
        # nn.Dropout2d is meant for inputs with spatial dimensions and PyTorch
        # deprecates feeding it 2-D input; nn.Dropout is the documented
        # replacement that keeps the same element-wise behaviour.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x: torch.Tensor):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: image batch; it is reshaped to ``[N, 28, 28, 1]`` (NHWC) before
                being converted to a sparse tensor.

        Returns:
            Result of ``log_softmax`` over dim 1 of the final linear layer's
            output.
        """
        # x: [N, 28, 28, 1], must be NHWC tensor
        x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x = self.net(x_sp)
        x = torch.flatten(x, 1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        args: options object; ``args.fp16`` selects mixed-precision training
            and ``args.log_interval`` is the number of batches between log
            lines.
        model: module to train; called as ``model(data)`` and expected to
            return values accepted by ``F.nll_loss``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``len()`` and a ``dataset`` attribute (used only for logging).
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    # The scaler is only used on the fp16 path, so enable it only there;
    # a disabled scaler avoids spurious "CUDA is not available" warnings on
    # fp32/CPU runs.
    scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16)
    amp_ctx = contextlib.nullcontext()
    if args.fp16:
        amp_ctx = torch.cuda.amp.autocast()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with amp_ctx:
            output = model(data)
            loss = F.nll_loss(output, target)
        if args.fp16:
            assert loss.dtype is torch.float32
            scaler.scale(loss).backward()
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            # To clip gradients, call scaler.unscale_(optimizer) first and then
            # torch.nn.utils.clip_grad_norm_ as usual (same max_norm as without scaling).
            scaler.step(optimizer)
            # Updates the scale for next iteration.
            scaler.update()
            # The current scale is deliberately not read back here:
            # scaler.get_scale() forces a CPU-GPU sync and the value is unused.
        else:
            loss.backward()
            optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Puts the model in eval mode, runs every batch under ``torch.no_grad()``
    (and under CUDA autocast when ``args.fp16`` is set), sums the NLL loss and
    the number of correct argmax predictions, then prints the per-sample
    average loss and the accuracy. Returns nothing.
    """
    model.eval()
    if args.fp16:
        amp_ctx = torch.cuda.amp.autocast()
    else:
        amp_ctx = contextlib.nullcontext()

    loss_total = 0
    num_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            with amp_ctx:
                output = model(data)
            # sum up batch loss
            batch_loss = F.nll_loss(output, target, reduction='sum')
            loss_total += batch_loss.item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            hits = pred.eq(target.view_as(pred))
            num_correct += hits.sum().item()

    num_samples = len(test_loader.dataset)
    loss_total /= num_samples
    accuracy_pct = 100. * num_correct / num_samples
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            loss_total, num_correct, num_samples, accuracy_pct))
def main():
    """Parse command-line options, then train and evaluate ``Net`` on MNIST.

    Builds the train/test data loaders (downloading the training split into
    ``../data`` if needed), trains for ``--epochs`` epochs with Adadelta and a
    per-epoch ``StepLR`` decay, evaluates after every epoch, and optionally
    saves the final weights to ``mnist_cnn.pt``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--fp16', action='store_true', default=False,
                        help='For mixed precision training')
    args = parser.parse_args()

    # CUDA availability is only queried when the user has not disabled it.
    if args.no_cuda:
        use_cuda = False
    else:
        use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    loader_kwargs = {}
    if use_cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}

    def make_transform():
        return transforms.Compose([
            transforms.ToTensor(),
            # here we remove norm to get sparse tensor with lots of zeros
            # transforms.Normalize((0.1307,), (0.3081,))
        ])

    train_set = datasets.MNIST('../data',
                               train=True,
                               download=True,
                               transform=make_transform())
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **loader_kwargs)

    test_set = datasets.MNIST('../data',
                              train=False,
                              transform=make_transform())
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_batch_size,
        shuffle=True,
        **loader_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    last_epoch = args.epochs
    epoch = 1
    while epoch <= last_epoch:
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()
        epoch += 1

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
[build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta"
......@@ -616,6 +616,7 @@ class SimpleConv:
algocore.get_conv_algo_desp_from_param(p)
for p in ALL_IMPGEMM_PARAMS
]
self.all_desps = all_desps
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
......@@ -648,13 +649,13 @@ class SimpleConv:
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
self.kc_forward_cache: Dict[Tuple[int, int, int, int, int, int, int,
int],
int, bool],
BestConvAlgoByProfile] = {} # for forward
self.kc_dgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = {
int, bool], BestConvAlgoByProfile] = {
} # for backward weight
self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = {
int, bool], BestConvAlgoByProfile] = {
} # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
......@@ -679,11 +680,12 @@ class SimpleConv:
op_type: ConvOpType,
mask_width: int,
fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 # and out.dtype == tv.float16
use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during
......@@ -703,6 +705,10 @@ class SimpleConv:
layout_w.interleave, layout_o.interleave, inp.dtype,
weight.dtype, out.dtype, op_type.value)
desps = self.static_key_to_desps.get(static_key, None)
# for d in self.all_desps:
# print(d)
# print(len(desps))
# breakpoint()
if desps is None or len(desps) == 0:
return finally_algos
for desp in desps:
......@@ -726,11 +732,21 @@ class SimpleConv:
ldw = weight.dim(-1)
ldo = out.dim(-1)
mask_width_valid = True
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0
require_dynamic_mask = kv > 32
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if not bias.empty() and not scale.empty():
# int8 inference, bias/scale dtype must equal to compute dtype in gemm
assert bias.dtype == scale.dtype
if desp.dcomp != bias.dtype:
continue
if not desp.is_int8_inference:
continue
else:
if desp.is_int8_inference:
continue
if desp.is_nvrtc:
if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
continue
......@@ -747,6 +763,12 @@ class SimpleConv:
continue
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
if require_dynamic_mask:
if not desp.dynamic_mask:
continue
else:
if desp.dynamic_mask:
continue
finally_algos.append(desp)
return finally_algos
......@@ -758,11 +780,12 @@ class SimpleConv:
k: int,
c: int,
arch: Tuple[int, int],
mask_width: int = -1):
mask_width: int = -1,
need_dynamic_mask: bool = False):
if not op_type == ConvOpType.kBackwardWeight:
# fwd and dgrad don't need
mask_width = -1
key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width)
key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width, need_dynamic_mask)
if op_type == ConvOpType.kForward:
return self.kc_forward_cache.get(key, None)
elif op_type == ConvOpType.kBackwardInput:
......@@ -795,8 +818,9 @@ class SimpleConv:
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
verbose=True,
custom_names=custom_names,
verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod.load()
return mod, kernel
......@@ -824,7 +848,6 @@ class SimpleConv:
mask_argsort: tv.Tensor,
indices: tv.Tensor,
reverse_mask: bool,
mask_int_count: int = 1,
mask_filter: int = 0xffffffff,
mask_width: int = -1,
mask_output: tv.Tensor = tv.Tensor(),
......@@ -832,17 +855,20 @@ class SimpleConv:
beta: float = 0.0,
stream: int = 0,
fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width,
fp32_accum, use_tf32)
fp32_accum, use_tf32, bias, scale)
inp = inp.clone()
weight = weight.clone()
output = output.clone()
print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty())
channel_k = output.dim(1)
channel_c = inp.dim(1)
weight = weight.view([channel_k, -1, channel_c])
need_dynamic_mask = weight.dim(1) > 32
times: List[float] = []
all_profile_res: List[BestConvAlgoByProfile] = []
group_by_algo = {}
......@@ -865,8 +891,9 @@ class SimpleConv:
params.indices = indices
params.mask = mask
params.mask_output = mask_output
params.mask_int_count = mask_int_count
if desp.is_int8_inference:
params.bias = bias
params.scale = scale
# if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput:
......@@ -909,7 +936,7 @@ class SimpleConv:
# fwd and dgrad don't need
mask_width = -1
key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
arch[0], arch[1], mask_width)
arch[0], arch[1], mask_width, need_dynamic_mask)
with self.lock:
if op_type == ConvOpType.kForward:
self.kc_forward_cache[key] = res
......@@ -945,7 +972,9 @@ class SimpleConv:
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
mask_int_count: Union[int, None] = None):
scale: Optional[tv.Tensor] = None,
output_add: Optional[tv.Tensor] = None):
channel_k = output.dim(1)
channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream)
......@@ -986,9 +1015,12 @@ class SimpleConv:
params.mask_filter = mask_filter
params.mask_output = mask_output
params.reverse_mask = reverse_mask
params.mask_int_count = mask_int_count
if bias is not None:
params.bias = bias
if output_add is not None and algo_desp.is_int8_inference:
params.output_add = output_add
if scale is not None and algo_desp.is_int8_inference:
params.scale = scale
if timer.enable:
assert timer._timer is not None
params.timer = timer._timer
......
......@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]):
desp.dtype_a = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_a.tv_dtype
desp.dtype_c = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_b.tv_dtype
desp.dtype_c = p.dtype_c.tv_dtype
desp.dacc = p.dtype_acc.tv_dtype
desp.dcomp = p.dtype_comp.tv_dtype
desp.trans_a = p.trans_a
......@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc
desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc
desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc
desp.is_int8_inference = ker.int8_inference
desp.dynamic_mask = ker.dynamic_mask
desp.min_arch = ker.min_arch()
return desp
......@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
desp.interleave_o)
p.mask_sparse = desp.mask_sparse
p.increment_k_first = desp.increment_k_first
p.int8_inference = desp.is_int8_inference
p.dynamic_mask = desp.dynamic_mask
return p
......@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.inference import InferenceOps
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
# all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
# all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main"
gemmtuner = GemmTunerSimple(cu)
......
......@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3],
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
]
IMPLGEMM_TURING_PARAMS = [
......@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
......@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True),
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
......
......@@ -144,7 +144,7 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
......@@ -167,11 +167,10 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
......@@ -194,11 +193,10 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
......@@ -214,7 +212,6 @@ class SpconvOps:
indice_pair_mask:
backward:
stream_int:
mask_int_count:
"""
...
@staticmethod
......@@ -383,65 +380,25 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask32(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
mask_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask32_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
mask_count:
"""
...
@staticmethod
......@@ -598,7 +555,7 @@ class SpconvOps:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......
......@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch:
"""
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]:
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> List[ConvAlgoDesp]:
"""
Args:
inp:
......@@ -38,6 +38,8 @@ class ConvTunerSimple:
auto_fp32_accum:
fp32_accum:
use_tf32:
bias:
scale:
"""
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
......@@ -48,7 +50,7 @@ class ConvTunerSimple:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
......@@ -72,14 +74,15 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
auto_fp32_accum:
fp32_accum:
num_run:
use_tf32:
bias:
scale:
"""
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1, need_dynamic_mask: bool = False) -> Tuple[Any, bool]:
"""
Args:
op_type:
......@@ -90,9 +93,10 @@ class ConvTunerSimple:
c:
arch:
mask_width:
need_dynamic_mask:
"""
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, scale: Tensor = Tensor(), output_add: Tensor = Tensor()) -> None:
"""
Args:
profile_res:
......@@ -110,7 +114,6 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
workspace:
verbose:
timer:
......@@ -119,6 +122,8 @@ class ConvTunerSimple:
act_alpha:
act_beta:
act_type:
scale:
output_add:
"""
...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
......
......@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
"""
Args:
allocator:
......@@ -75,7 +75,6 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
num_activate_out:
masks:
mask_int_count:
arch:
is_train:
is_subm:
......@@ -88,10 +87,14 @@ class ConvGemmOps:
act_beta:
act_type:
use_tf32:
output_scale:
scale:
output_add:
output_add_scale:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
"""
Args:
allocator:
......@@ -107,7 +110,6 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
mask_int_count:
arch:
mask_width:
is_subm:
......
......@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys
import re
import os
from cumm.gemm.codeops import dispatch
class CustomThrustLib(pccm.Class):
def __init__(self):
super().__init__()
......@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
......@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
......@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
......@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward,
stream_int, mask_int_count);
stream_int);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
""")
return code
def sort_1d_by_key_allocator_template(self, use_allocator: bool, int_count: int = 1):
def sort_1d_by_key_allocator_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
......@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code.arg("mask_count", "int", "1", pyanno="int")
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator:
......@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0) / {int_count}}}, tv::int32, 0);
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T_ = TV_DECLTYPE(I);
using T = {"T_" if int_count == 1 else f"thrust::tuple<{', '.join(['T_'] * int_count) }>"};
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) / {int_count}, ptr_k);
}});
""")
# nested tv::dispatch may cause compiler bug in msvc.
for dtype in dispatch(code, [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64], "data.dtype()"):
code.raw(f"""
using T_ = {dtype};
tv::dispatch_int<1, 2, 3, 4>(mask_count, [&](auto IV){{
constexpr int I = TV_DECLTYPE(IV)::value;
// we can't use thrust::tuple in mp_repeat_c directly because
// thrust tuple actually has fixed size template arguments.
using T = tv::mp_rename<tv::mp_repeat_c<tv::mp_list<T_>, I>, thrust::tuple>;
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}});
""")
code.raw(f"""
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
""")
......@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32(self):
# for python
def sort_1d_by_key_allocator(self):
return self.sort_1d_by_key_allocator_template(False)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128(self):
# for python
return self.sort_1d_by_key_allocator_template(False, 4)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True, 4)
def sort_1d_by_key_allocator_mask_auto_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
if not use_allocator:
code.arg("alloc_param", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("alloc_param", "ThrustAllocator&")
code.arg("indices",
"tv::Tensor",
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto(self):
return self.sort_1d_by_key_allocator_mask_auto_template(False)
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto_v2(self):
return self.sort_1d_by_key_allocator_mask_auto_template(True)
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_v2(self):
# for cpp only
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
......@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
code.raw(f"""
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory);
hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}}
size_t res = 0;
if (subm){{
......@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory);
hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}}
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
......@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int mask_int_count = (kv + 31) / 32;
if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
int mask_int_count = tv::div_up(kv, 32);
// if (mask_int_count > 1 && mask_int_count < 4)
// mask_int_count = 4;
// TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape;
if (!subm){{
......@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in * mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_in, mask_int_count}}, tv::uint32, 0, stream_int);
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int, mask_int_count);
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}}
""")
with code.else_():
......@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out * mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_out, mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0) * mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, indices.dim(0), mask_int_count}}, tv::uint32, 0, stream_int);
}}
}}
if (!direct_table){{
......@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count);
transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count);
transposed, stream_int);
}}
}}
""")
......@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
}}
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_bwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
}}
}}
}}
""")
code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out, mask_int_count);
return std::make_tuple(mask_tensor, num_act_out);
""")
return code.ret("std::tuple<tv::Tensor, int, int>")
return code.ret("std::tuple<tv::Tensor, int>")
@pccm.pybind.mark
@pccm.static_function
......
......@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
"int, int, int, int, int>"))
self.add_typedef(
"algo_cache_key_t", "std::tuple<int, int, int, int, "
"int, int, int, int>")
"int, int, int, int, bool>")
self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>")
self.add_member(
......@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool")
code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
......@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
bool require_dynamic_mask = kv > 32;
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!bias.empty() && !scale.empty()){{
TV_ASSERT_RT_ERR(bias.dtype() == scale.dtype(), "bias/scale dtype must equal to compute dtype in gemm");
if (desp.dcomp != bias.dtype()){{
continue;
}}
if (!desp.is_int8_inference){{
continue;
}}
}}else{{
if (desp.is_int8_inference){{
continue;
}}
}}
auto desp2 = desp;
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
......@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
}}
}}
if (require_dynamic_mask){{
if (!desp.dynamic_mask){{
continue;
}}
}}else{{
if (desp.dynamic_mask){{
continue;
}}
}}
finally_algos.push_back(desp2);
}}
}}
......@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
......@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum, use_tf32);
auto_fp32_accum, fp32_accum, use_tf32,
bias, scale);
inp = inp.clone();
weight = weight.clone();
bool need_dynamic_mask = weight.dim(1) > 32;
output = output.clone();
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
weight = weight.view(channel_k, -1, channel_c);
std::vector<ConvTuneResult> all_profile_res;
std::unordered_set<int> splitk_tests;
......@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices;
params.mask = mask;
params.mask_output = mask_output;
params.mask_int_count = mask_int_count;
if (desp.is_int8_inference){{
params.bias = bias;
params.scale = scale;
}}
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
......@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width);
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
{{
std::lock_guard<std::mutex> guard(mutex_);
......@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("k, c", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int", "-1")
code.arg("need_dynamic_mask", "bool", "false")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("std::tuple<ConvTuneResult, bool>")
......@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width);
std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
ConvTuneResult res;
bool exists = false;
{{
......@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false")
......@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
......@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.output = output;
params.verbose = verbose;
params.bias = bias;
params.scale = scale;
params.split_k_slices = split_k_slices;
params.alpha = alpha;
......@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.act_alpha = act_alpha;
params.act_beta = act_beta;
params.act_type = act_type;
if (!output_add.empty() && desp.is_int8_inference){{
params.output_add = output_add;
}}
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
......@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{
params.timer = timer;
......@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>")
code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false")
......@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true")
code.arg("output_scale", "float", "1.0")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add_scale", "float", "1.0")
code.arg("output_dtype", "int", "-1")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
......@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
int mask_int_count = tv::div_up(kv, 32);
tv::Tensor out_features;
if (output_dtype < 0){{
output_dtype = int(features.dtype());
}}
if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}}
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
......@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum,
fp32_accum,
5, // num_run
use_tf32);
use_tf32,
bias,
scale);
tune_res = std::get<0>(tune_res_time);
}}
float alpha = 1.0;
if (tune_res.algo_desp.is_int8_inference){{
alpha = output_scale;
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width) * mask_int_count}},
{{num_split, tv::div_up(num_activate_out, mask_width), mask_int_count}},
tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
......@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
if (!bias.empty()){{
if (!bias.empty() && !tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = 1;
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
}}
if (j > 0){{
bias = tv::Tensor();
}}
......@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0, beta,
alpha, beta,
stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace
false, // verbose
timer,
......@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
bias,
act_alpha,
act_beta,
act_type);
act_type,
scale,
output_add);
}}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
......@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int")
......@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
......@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
......@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width
1.0, beta,
stream_int,
mask_int_count,
tv::Tensor(), // workspace
false, // verbose
timer);
......@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width,
1.0, 0.0,
stream_int,
mask_int_count,
workspace, // workspace
false, // verbose
timer);
......
......@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
......@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
// indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context();
......@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()")
code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
int num_act_in_real = indices.dim(0);
......@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
padding[i] = (ksize[i] / 2) * dilation[i];
}}
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize, {self.ndim + 1}]
......@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
// indice_pair_mask: [mask_split_count, num_act_in]
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 3, "error");
// indice_pair_mask: [mask_split_count, num_act_in, num_mask_per_point]
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
......@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{
// indice_pair_mask: [1, num_act_in]
// indice_pair_mask: [1, num_act_in, num_mask_per_point]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1)
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
if (mask_int_count == 1){{
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
}}
else{{
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
}}
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
......
......@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(),
layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(),
......
......@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
from torch.nn.init import calculate_gain
from cumm import tensorview as tv
from torch.nn import functional as F
FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, act_beta: float):
    """Apply the activation selected by ``act_type`` to ``x`` using plain torch ops.

    Args:
        x: tensor to activate.
        act_type: activation selector; None_, ReLU, Sigmoid and LeakyReLU
            are handled.
        act_alpha: forwarded to ``F.leaky_relu`` as its second positional
            argument (the negative slope) for LeakyReLU; ignored otherwise.
        act_beta: accepted for signature parity with the fused kernels but
            not read by any activation handled here.

    Returns:
        ``x`` itself (no copy) for ``Activation.None_``, otherwise the
        activated tensor.

    Raises:
        NotImplementedError: if ``act_type`` is not one of the handled values.
    """
    if act_type == tv.gemm.Activation.None_:
        return x
    if act_type == tv.gemm.Activation.ReLU:
        return F.relu(x)
    if act_type == tv.gemm.Activation.Sigmoid:
        # torch.sigmoid instead of F.sigmoid: same result, but the functional
        # alias emits a deprecation warning on older torch releases.
        return torch.sigmoid(x)
    if act_type == tv.gemm.Activation.LeakyReLU:
        return F.leaky_relu(x, act_alpha)
    raise NotImplementedError(f"unsupported activation type: {act_type}")
class SparseConvolution(SparseModule):
__constants__ = [
......@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count
if algo is None:
if kv <= 32 and not CPU_ONLY_BUILD:
if kv <= 128 and not CPU_ONLY_BUILD:
if kv < 8:
algo = ConvAlgo.MaskImplicitGemm
else:
......@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
self.act_type = act_type
self.act_alpha = act_alpha
self.act_beta = act_beta
self.enable_int8_test_mode: bool = False
self._int8_weight = torch.Tensor()
# calculated by max(abs(weight)) for each channel
self._int8_weight_scale = torch.Tensor()
# calculated by scale self.bias with _int8_input_scale
self._int8_bias = torch.Tensor()
# int8 inference must set _int8_input_scale
self._int8_input_scale: Optional[float] = None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self._int8_output_scale: Optional[float] = None
if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
self.reset_parameters()
......@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
def set_int8_test(self,
                  enable: bool,
                  input_scale: float,
                  output_scale: Optional[float] = None,
                  weight_scale: Optional[torch.Tensor] = None):
    """Configure the int8 test-mode scales and switch the mode on or off.

    Args:
        enable: new value for ``enable_int8_test_mode``.
        input_scale: stored as ``_int8_input_scale``.
        output_scale: stored as ``_int8_output_scale``; may be None.
        weight_scale: when given, replaces ``_int8_weight_scale``; when
            None the current value is left untouched.

    NOTE(review): int8 tensors already stored on the module are not reset
    here -- confirm whether calling this a second time with different scales
    is expected to work.
    """
    self.enable_int8_test_mode = enable
    self._int8_input_scale, self._int8_output_scale = input_scale, output_scale
    if weight_scale is None:
        # keep whatever weight scale is already stored
        return
    self._int8_weight_scale = weight_scale
def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs):
if self.record_voxel_count and not self.subm and not self.inverse and _MAX_NUM_VOXELS_DURING_TRAINING not in state_dict:
state_dict[prefix + _MAX_NUM_VOXELS_DURING_TRAINING] = torch.zeros(
name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT:
return
......@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
def is_inverseable(self):
    """Return True when this layer has an indice key and is not submanifold."""
    if self.indice_key is None:
        # no indice key set
        return False
    return not self.subm
def forward(self, input: SparseConvTensor):
def forward(self,
            input: SparseConvTensor,
            add_input: Optional[SparseConvTensor] = None):
    """Run the convolution with the module's own weight and bias.

    Delegates to ``_conv_forward`` so that the weight and bias are passed
    explicitly instead of being read from the module inside the
    implementation.

    Args:
        input: sparse tensor to convolve.
        add_input: optional tensor forwarded unchanged to ``_conv_forward``.

    Returns:
        Whatever ``_conv_forward`` returns.
    """
    weight, bias = self.weight, self.bias
    return self._conv_forward(input, weight, bias, add_input)
def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None):
assert isinstance(input, SparseConvTensor)
assert input.features.shape[
1] == self.in_channels, "channel size mismatch"
......@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
indices = input.indices
spatial_shape = input.spatial_shape
batch_size = input.batch_size
bias_for_training = self.bias if self.training else None
bias_for_infer = self.bias if not self.training else None
bias_for_training = bias if self.training else None
bias_for_infer = bias if not self.training else None
output_scale = None
output_add_scale = 1.0
if self.enable_int8_test_mode:
assert not self.training, "must in eval mode"
assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
assert bias_for_infer is not None, "conv-bn-relu must be fused"
assert self._int8_input_scale is not None
if features.dtype != torch.int8:
# quantize
features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
output_scale = self._int8_output_scale
int8_out_scale = output_scale
if int8_out_scale is None:
int8_out_scale = 1
if add_input is not None:
assert add_input.int8_scale is not None, "only support int8 add"
output_add_scale = add_input.int8_scale
if self._int8_weight.numel() == 0:
with torch.no_grad():
assert ALL_WEIGHT_IS_KRSC
weight_scales = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
num_1s = [1] * (self.ndim + 1)
self._int8_weight = (weight / weight_scales.view(self.out_channels, *num_1s) * 127).to(torch.int8)
if self._int8_weight_scale.numel() == 0:
self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scales)
self._int8_bias = bias_for_infer * int8_out_scale
if self.training:
msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg
......@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
"out_channels": self.out_channels,
}
}
if self.conv1x1:
if self.conv1x1 and not self.enable_int8_test_mode:
# in int8 test mode, we don't implement conv1x1 via mm.
if FILTER_HWIO:
features = torch.mm(
input.features,
self.weight.view(self.out_channels, self.in_channels).T)
weight.view(self.out_channels, self.in_channels).T)
else:
features = torch.mm(
input.features,
self.weight.view(self.in_channels, self.out_channels))
weight.view(self.in_channels, self.out_channels))
if self.bias is not None:
features += self.bias
if bias is not None:
features += bias
out_tensor = out_tensor.replace_feature(features)
# padding may change spatial shape of conv 1x1.
out_tensor.spatial_shape = out_spatial_shape
......@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
if self.subm:
out_features = Fsp.indice_subm_conv(
features,
self.weight,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
......@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
if self.inverse:
out_features = Fsp.indice_inverse_conv(
features,
self.weight,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
......@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
else:
out_features = Fsp.indice_conv(
features,
self.weight,
weight,
indice_pairs_calc,
indice_pair_num,
outids.shape[0],
......@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks
mask_int_count = datas.mask_int_count
assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape,
datas)
else:
if input.benchmark:
torch.cuda.synchronize()
t = time.time()
with input._timer.namespace("gen_pairs"):
# we need to gen bwd indices for regular conv
# because it may be inversed.
......@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
print(msg, file=sys.stderr)
spconv_save_debug_data(indices)
raise e
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
......@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7]
masks = res[8]
mask_int_count = res[9]
if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData(
outids,
......@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
ksize=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask_int_count=mask_int_count)
dilation=self.dilation)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
......@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
torch.cuda.synchronize()
t = time.time()
num_activate_out = outids.shape[0]
out_features = Fsp.implicit_gemm(
features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, mask_int_count, self.training, self.subm,
input._timer, self.fp32_accum,
bias_for_infer,
self.act_alpha,
self.act_beta,
self.act_type)
weight_cur = weight
bias_cur = bias_for_infer
if self.enable_int8_test_mode:
assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
weight_cur = self._int8_weight
bias_cur = self._int8_bias
if self.training:
out_features = Fsp.implicit_gemm(
features, weight_cur, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type)
else:
output_dtype = None
if self._int8_output_scale is None:
output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type,
# TODO do we really need output scale to scale bias in kernel?
1.0, # output_scale
self._int8_weight_scale, # scale
output_add=add_input.features if add_input is not None else None,
output_add_scale=output_add_scale,
output_dtype=output_dtype)
if bias_for_training is not None:
out_features += bias_for_training
if input.benchmark:
......@@ -581,9 +675,10 @@ class SparseConvolution(SparseModule):
out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape
# print(outids.shape, spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, out_spatial_shape)
if add_input is not None and not self.enable_int8_test_mode:
out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
out_tensor.int8_scale = output_scale
return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor,
......
......@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None,
mask_int_count: int=1):
out_voxel_num: Optional[Any] = None):
self.out_indices = out_indices
self.indices = indices
self.pair_fwd = pair_fwd
......@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num
self.mask_int_count = mask_int_count
def scatter_nd(indices, updates, shape):
......@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo
# for simple int8 torch inference
self.int8_scale: Optional[float] = None
def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
......
......@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int,
masks: List[np.ndarray],
mask_int_count: int,
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
try:
out, mask_out, mask_width = ops.implicit_gemm(
features, filters, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits, num_activate_out, masks, mask_int_count, is_train,
mask_argsort_fwd_splits, num_activate_out, masks, is_train,
is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
act_type)
except Exception as e:
......@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
ctx.masks = masks
ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum
ctx.mask_int_count = mask_int_count
return out
@staticmethod
......@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
is_subm = ctx.is_subm
timer = ctx.timer
fp32_accum = ctx.fp32_accum
mask_int_count = ctx.mask_int_count
try:
input_bp, filters_bp = ops.implicit_gemm_backward(
......@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits,
mask_output_fwd=mask_out,
masks=masks,
mask_int_count=mask_int_count,
mask_width=mask_width,
is_subm=is_subm,
timer=timer,
......@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, masks))
raise e
None_9 = [None] * 17
None_9 = [None] * 16
return (input_bp, filters_bp, *None_9)
......
......@@ -23,7 +23,7 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo
from typing import Dict, List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.pytorch.cppcore import _TORCH_DTYPE_TO_TV, TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32
......@@ -31,7 +31,7 @@ import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps
from spconv.cppconstants import CPU_ONLY_BUILD
from cumm.gemm.codeops import div_up
from spconv.utils import nullcontext
if not CPU_ONLY_BUILD:
......@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm(
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
mask_tensor, num_act_out, mask_int_count = SpconvOps.get_indice_pairs_implicit_gemm(
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc,
torch_tensor_to_tv(indices),
batch_size,
......@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm(
assert pair.shape[0] == 2
pair_bwd = pair[1]
return (out_inds, indice_num_per_loc, pair[0], pair_bwd,
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor())
pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
......@@ -437,16 +437,13 @@ def get_indice_pairs_implicit_gemm(
]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume.
# assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
mask_int_count = (kv + 31) // 32
if 1 < mask_int_count < 4:
mask_int_count = 4
assert mask_int_count in [1, 4]
mask_int_count = div_up(kv, 32)
if not subm:
if transpose:
......@@ -511,7 +508,7 @@ def get_indice_pairs_implicit_gemm(
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0] * mask_int_count),
pair_mask = torch.empty((mask_split_count, indices.shape[0], mask_int_count),
dtype=torch.int32,
device=indices.device)
......@@ -531,8 +528,7 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
indice_pair_mask=pair_mask_tv,
backward=is_train,
stream_int=stream,
mask_int_count=mask_int_count)
stream_int=stream)
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream)
......@@ -549,10 +545,10 @@ def get_indice_pairs_implicit_gemm(
# so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++).
# f**k thrust
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_tv[j],
alloc.alloc,
mask_argsort_tv[j], stream,
mask_int_count)
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j],
alloc.alloc,
mask_argsort_tv[j], stream,
mask_int_count)
# CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [
......@@ -560,10 +556,10 @@ def get_indice_pairs_implicit_gemm(
]
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count)
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation)
......@@ -621,7 +617,6 @@ def get_indice_pairs_implicit_gemm(
stream_int=stream)
uniq_out_indices_offset_tv = tv.Tensor()
with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream):
if direct_table:
uniq_cnt = torch.zeros([1],
dtype=torch.int32,
......@@ -655,7 +650,7 @@ def get_indice_pairs_implicit_gemm(
-1,
dtype=indices.dtype,
device=indices.device)
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out * mask_int_count),
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out, mask_int_count),
dtype=torch.int32,
device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
......@@ -665,7 +660,7 @@ def get_indice_pairs_implicit_gemm(
pair_mask_bwd_tv = tv.Tensor()
if is_train:
pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0] * mask_int_count),
(mask_split_count, indices.shape[0], mask_int_count),
dtype=torch.int32,
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
......@@ -713,8 +708,7 @@ def get_indice_pairs_implicit_gemm(
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream,
mask_int_count=mask_int_count)
stream_int=stream)
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
......@@ -766,24 +760,24 @@ def get_indice_pairs_implicit_gemm(
else:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if not is_train:
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_fwd_tv[0],
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc,
mask_argsort_fwd_tv[0],
stream,
mask_int_count)
else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator_mask_auto(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count)
else:
SpconvOps.sort_1d_by_key_allocator_mask_auto(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count)
......@@ -808,7 +802,7 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count)
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
def indice_conv(features: torch.Tensor,
......@@ -1457,7 +1451,6 @@ def implicit_gemm(features: torch.Tensor,
mask_argsort_fwd_splits: List[torch.Tensor],
num_activate_out: int,
masks: List[np.ndarray],
mask_int_count: int,
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -1465,16 +1458,31 @@ def implicit_gemm(features: torch.Tensor,
bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
output_scale: float = 1.0,
scale: Optional[torch.Tensor] = None,
output_add: Optional[torch.Tensor] = None,
output_add_scale: float = 1.0,
output_dtype: Optional[torch.dtype] = None):
stream = get_current_stream()
bias_tv = tv.Tensor()
scale_tv = tv.Tensor()
output_add_tv = tv.Tensor()
if output_add is not None:
assert features.dtype == torch.int8, "fused residual add only support int8"
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
if scale is not None:
scale_tv = torch_tensor_to_tv(scale)
if output_add is not None:
output_add_tv = torch_tensor_to_tv(output_add)
if not features.is_contiguous():
features = features.contiguous()
assert features.is_contiguous()
assert filters.is_contiguous()
if output_dtype is None:
output_dtype = features.dtype
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
......@@ -1497,13 +1505,15 @@ def implicit_gemm(features: torch.Tensor,
if fp32_accum is None:
fp32_accum = False
arch = get_arch()
output_dtype_tv = _TORCH_DTYPE_TO_TV[output_dtype]
mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, mask_int_count, arch, is_train, is_subm, stream,
num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=constants.SPCONV_ALLOW_TF32)
use_tf32=constants.SPCONV_ALLOW_TF32, output_scale=output_scale,
scale=scale_tv, output_add=output_add_tv, output_add_scale=output_add_scale,
output_dtype=output_dtype_tv)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
......@@ -1515,8 +1525,8 @@ def implicit_gemm(features: torch.Tensor,
# t = time.time()
if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
# if features.dtype == torch.int8 or features.dtype == torch.qint8:
# raise NotImplementedError("work in progress")
# here filters is KRSC
masks_ints = [m.item() for m in masks]
out_channel = filters.shape[0]
......@@ -1524,13 +1534,14 @@ def implicit_gemm(features: torch.Tensor,
num_split = len(pair_mask_fwd_splits)
filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1]
mask_int_count = div_up(kv, 32)
if is_subm:
out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype,
dtype=output_dtype,
device=features.device)
else:
out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype,
dtype=output_dtype,
device=features.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
......@@ -1568,13 +1579,13 @@ def implicit_gemm(features: torch.Tensor,
stream=stream,
fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
bias=bias_tv, scale=scale_tv)
mask_width = tune_res.algo_desp.tile_shape[0]
if is_train:
mask_output_fwd = torch.empty(
[num_split,
codeops.div_up(num_activate_out, mask_width) * mask_int_count],
codeops.div_up(num_activate_out, mask_width), mask_int_count],
dtype=torch.int32,
device=features.device)
# pytorch don't support uint32.
......@@ -1597,12 +1608,16 @@ def implicit_gemm(features: torch.Tensor,
bias_tv = tv.Tensor()
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
alpha = 1.0
if tune_res.algo_desp.is_int8_inference:
alpha = output_scale
with timer.record("implicit_gemm", stream):
for j in range(num_split):
beta = 0 if j == 0 else 1
if bias is not None:
if bias is not None and not tune_res.algo_desp.is_int8_inference:
beta = 1
if output_add is not None and tune_res.algo_desp.is_int8_inference:
beta = output_add_scale
CONV.run_with_tuned_result(
tune_res,
ConvOpType.kForward,
......@@ -1616,6 +1631,7 @@ def implicit_gemm(features: torch.Tensor,
reverse_mask=False,
mask_filter=masks_ints[j],
mask_width=-1,
alpha=alpha,
beta=beta,
stream=stream,
verbose=False,
......@@ -1623,91 +1639,8 @@ def implicit_gemm(features: torch.Tensor,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta,
mask_int_count=mask_int_count)
# INT8_TEST = True
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# return out_features, mask_output_fwd, mask_width
# features = features.to(torch.int8)
# filters = filters.to(torch.int8)
# out_features_i8 = out_features.to(torch.int8)
# features_tv = torch_tensor_to_tv(features)
# filters_tv = torch_tensor_to_tv(filters)
# out_features_i8_tv = torch_tensor_to_tv(out_features_i8)
# tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
# filters_tv.dtype, out_features_i8_tv.dtype,
# out_channel, in_channel, arch)
# if tune_res is None:
# tune_res, _ = CONV.tune_and_cache(
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# NHWC,
# KRSC,
# NHWC,
# arch,
# mask=pair_mask_fwd_split_tvs[0],
# mask_argsort=mask_argsort_fwd_split_tvs[0],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks[0].item(),
# stream=stream,
# fp32_accum=fp32_accum)
# mask_width = tune_res.algo_desp.tile_shape[0]
# if is_train:
# mask_output_fwd = torch.empty(
# [num_split,
# codeops.div_up(num_activate_out, mask_width)],
# dtype=torch.int32,
# device=features.device)
# # pytorch don't support uint32.
# mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
# dtype=tv.uint32)
# mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
# else:
# mask_output_fwd = None
# mask_output_fwd_tv = tv.Tensor()
# mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]
# # CONV.stream_synchronize(stream)
# # print("FPREPARE", time.time() - t)
# # # t = time.time()
# # CONV.stream_synchronize(stream)
# # t = time.time()
# # print(tune_res.algo_desp)
# with tv.measure_and_print(f"i8 time {features.shape[0]}-{in_channel}-{out_channel}"):
# with timer.record("implicit_gemm_i8", stream):
# for j in range(num_split):
# beta = 0 if j == 0 else 1
# CONV.run_with_tuned_result(
# tune_res,
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# mask=pair_mask_fwd_split_tvs[j],
# mask_argsort=mask_argsort_fwd_split_tvs[j],
# mask_output=mask_output_fwd_tvs[j],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks_ints[j],
# mask_width=-1,
# beta=beta,
# stream=stream,
# verbose=False)
# torch.cuda.synchronize()
# if DEBUG:
# CONV.stream_synchronize(stream)
# dura = time.time() - t
# print("F", tune_res.algo_desp, dura)
# print(out_features.mean(), out_features.max(), out_features.min())
scale=scale_tv,
output_add=output_add)
return out_features, mask_output_fwd, mask_width
......@@ -1722,7 +1655,6 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray],
mask_int_count: int,
mask_width: int,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
......@@ -1782,7 +1714,7 @@ def implicit_gemm_backward(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, mask_int_count, arch, mask_width, is_subm, stream,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn]
......@@ -1802,6 +1734,8 @@ def implicit_gemm_backward(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1]
need_dynamic_mask = kv > 32
mask_int_count = div_up(kv, 32)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
......@@ -1831,11 +1765,12 @@ def implicit_gemm_backward(features: torch.Tensor,
dgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardInput,
din_tv.dtype, filters_tv.dtype,
dout_tv.dtype, out_channel,
in_channel, arch)
in_channel, arch, need_dynamic_mask=need_dynamic_mask)
wgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardWeight,
features_tv.dtype, dfilters_tv.dtype,
dout_tv.dtype, out_channel,
in_channel, arch, mask_width)
in_channel, arch, mask_width,
need_dynamic_mask=need_dynamic_mask)
if dgrad_tune_res is None:
# TODO split mask maybe completely invalid
......@@ -1861,8 +1796,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(),
stream=stream,
fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
use_tf32=constants.SPCONV_ALLOW_TF32)
if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight,
......@@ -1881,8 +1815,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_output=tv.Tensor(),
mask_width=mask_width,
stream=stream,
use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count)
use_tf32=constants.SPCONV_ALLOW_TF32)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight,
......@@ -1919,8 +1852,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[j].item(),
mask_width=-1,
beta=beta,
stream=stream,
mask_int_count=mask_int_count)
stream=stream)
# for backward weight, beta = 0 because each split
# handle different kernel locations.
# TODO remove D iterator in backward weight kernel
......@@ -1939,8 +1871,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width=mask_width,
beta=0,
workspace=workspace_tv,
stream=stream,
mask_int_count=mask_int_count)
stream=stream)
return (din, dfilters.reshape(filters_shape))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment