Commit aa26c99e authored by yan.yan's avatar yan.yan
Browse files

working on quantization

parent ee8c9465
...@@ -207,7 +207,7 @@ class Net(nn.Module): ...@@ -207,7 +207,7 @@ class Net(nn.Module):
pool_algo = algo pool_algo = algo
# pool_algo = ConvAlgo.Native # pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0", spconv.SubMConv3d(16, 64, 3, bias=False, indice_key="c0",
algo=algo), algo=algo),
nn.BatchNorm1d(64), nn.BatchNorm1d(64),
nn.ReLU(), nn.ReLU(),
...@@ -373,6 +373,11 @@ class Net(nn.Module): ...@@ -373,6 +373,11 @@ class Net(nn.Module):
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, voxel_num=vx_num) x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, voxel_num=vx_num)
return self.net(x) return self.net(x)
def _set_enable_int8_test_inplace(simple_module: torch.fx.GraphModule, enable: bool):
for m in simple_module.modules():
if isinstance(m, SparseConvolution):
if m.in_channels % 32 == 0 and m.out_channels % 32 == 0:
m.enable_int8_test_mode = enable
class MyTracer(torch.fx.Tracer): class MyTracer(torch.fx.Tracer):
...@@ -387,6 +392,7 @@ def main(): ...@@ -387,6 +392,7 @@ def main():
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False
with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f: with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f) (voxels, coors, spatial_shape) = pickle.load(f)
voxels = np.random.uniform(-1, 1, size=[voxels.shape[0], 16]).astype(np.float32)
np.random.seed(50051) np.random.seed(50051)
device = torch.device("cuda:0") device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0") device_cpu = torch.device("cpu:0")
...@@ -408,6 +414,10 @@ def main(): ...@@ -408,6 +414,10 @@ def main():
out_fused = net_fused(voxels_th_cuda, coors_th_cuda, 1) out_fused = net_fused(voxels_th_cuda, coors_th_cuda, 1)
res = Fsp.sparse_add_hash_based(out_ref, out_fused.minus()) res = Fsp.sparse_add_hash_based(out_ref, out_fused.minus())
print(torch.linalg.norm(res.features)) print(torch.linalg.norm(res.features))
_set_enable_int8_test_inplace(net_fused, True)
qvoxels_cuda = voxels_th_cuda.to(torch.int8)
out_int8 = net_fused(qvoxels_cuda, coors_th_cuda, 1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -426,7 +426,7 @@ int main(int argc, char **argv) { ...@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
{SPCONV_ALLOC_OUT_FEATURES, out_features}}; {SPCONV_ALLOC_OUT_FEATURES, out_features}};
StaticAllocator alloc2(tensor_dict); StaticAllocator alloc2(tensor_dict);
ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp()); ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp());
auto conv_res = ConvGemmOps::implicit_gemm( auto conv_run_status = ConvGemmOps::implicit_gemm(
alloc2, tuner, input_features_real, weights, pair_fwd_real, alloc2, tuner, input_features_real, weights, pair_fwd_real,
pair_mask_splits, mask_argsort_splits, num_act_out_real, pair_mask_splits, mask_argsort_splits, num_act_out_real,
mask_tensor, arch, false, is_subm, mask_tensor, arch, false, is_subm,
...@@ -435,7 +435,7 @@ int main(int argc, char **argv) { ...@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
1.0 /*bias alpha, only used for leaky relu*/, 1.0 /*bias alpha, only used for leaky relu*/,
0.0 /*unused for now*/, tv::gemm::Activation::kReLU); 0.0 /*unused for now*/, tv::gemm::Activation::kReLU);
tv::ssprint("selected conv algo", tv::ssprint("selected conv algo",
std::get<1>(conv_res).algo_desp.__repr__()); std::get<1>(conv_run_status).algo_desp.__repr__());
// FINISH!!! // FINISH!!!
} }
// calc maximum number of output points. // calc maximum number of output points.
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import torch
import spconv.pytorch as spconv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
@contextlib.contextmanager
def identity_ctx():
yield
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = spconv.SparseSequential(
nn.BatchNorm1d(1),
spconv.SubMConv2d(1, 32, 3, 1),
nn.ReLU(),
spconv.SubMConv2d(32, 64, 3, 1),
nn.ReLU(),
spconv.SparseConv2d(64, 64, 2, 2),
spconv.ToDense(),
)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
def forward(self, x: torch.Tensor):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp)
x = torch.flatten(x, 1)
x = self.dropout1(x)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
scaler = torch.cuda.amp.grad_scaler.GradScaler()
amp_ctx = contextlib.nullcontext()
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with amp_ctx:
output = model(data)
loss = F.nll_loss(output, target)
scale = 1.0
if args.fp16:
assert loss.dtype is torch.float32
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
# scaler.unscale_(optim)
# Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
# You may use the same value for max_norm here as you would without gradient scaling.
# torch.nn.utils.clip_grad_norm_(models[0].net.parameters(), max_norm=0.1)
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
scale = scaler.get_scale()
else:
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
amp_ctx = contextlib.nullcontext()
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with amp_ctx:
output = model(data)
test_loss += F.nll_loss(
output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(
dim=1,
keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs',
type=int,
default=14,
metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr',
type=float,
default=1.0,
metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma',
type=float,
default=0.7,
metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument('--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model',
action='store_true',
default=False,
help='For Saving the current Model')
parser.add_argument('--fp16',
action='store_true',
default=False,
help='For mixed precision training')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
if __name__ == '__main__':
main()
[build-system] [build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"] requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"] # requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
...@@ -616,6 +616,7 @@ class SimpleConv: ...@@ -616,6 +616,7 @@ class SimpleConv:
algocore.get_conv_algo_desp_from_param(p) algocore.get_conv_algo_desp_from_param(p)
for p in ALL_IMPGEMM_PARAMS for p in ALL_IMPGEMM_PARAMS
] ]
self.all_desps = all_desps
self.prebuilt_desps = prebuilt_desps self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps} self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
...@@ -648,13 +649,13 @@ class SimpleConv: ...@@ -648,13 +649,13 @@ class SimpleConv:
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos) tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
self.kc_forward_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_forward_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], int, bool],
BestConvAlgoByProfile] = {} # for forward BestConvAlgoByProfile] = {} # for forward
self.kc_dgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_dgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int, bool], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int, bool], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {} self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
...@@ -679,11 +680,12 @@ class SimpleConv: ...@@ -679,11 +680,12 @@ class SimpleConv:
op_type: ConvOpType, op_type: ConvOpType,
mask_width: int, mask_width: int,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
use_tf32: bool = True): use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16 is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 # and out.dtype == tv.float16
use_f32_as_accum = False use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1])) kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during # for 3d conv, if reduce axis is too large, may cause nan during
...@@ -703,6 +705,10 @@ class SimpleConv: ...@@ -703,6 +705,10 @@ class SimpleConv:
layout_w.interleave, layout_o.interleave, inp.dtype, layout_w.interleave, layout_o.interleave, inp.dtype,
weight.dtype, out.dtype, op_type.value) weight.dtype, out.dtype, op_type.value)
desps = self.static_key_to_desps.get(static_key, None) desps = self.static_key_to_desps.get(static_key, None)
# for d in self.all_desps:
# print(d)
# print(len(desps))
# breakpoint()
if desps is None or len(desps) == 0: if desps is None or len(desps) == 0:
return finally_algos return finally_algos
for desp in desps: for desp in desps:
...@@ -726,11 +732,21 @@ class SimpleConv: ...@@ -726,11 +732,21 @@ class SimpleConv:
ldw = weight.dim(-1) ldw = weight.dim(-1)
ldo = out.dim(-1) ldo = out.dim(-1)
mask_width_valid = True mask_width_valid = True
if desp.op_type.value == ConvOpType.kBackwardWeight.value: if desp.op_type.value == ConvOpType.kBackwardWeight.value:
assert mask_width > 0 assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0
require_dynamic_mask = kv > 32
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if not bias.empty() and not scale.empty():
# int8 inference, bias/scale dtype must equal to compute dtype in gemm
assert bias.dtype == scale.dtype
if desp.dcomp != bias.dtype:
continue
if not desp.is_int8_inference:
continue
else:
if desp.is_int8_inference:
continue
if desp.is_nvrtc: if desp.is_nvrtc:
if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch): if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
continue continue
...@@ -747,6 +763,12 @@ class SimpleConv: ...@@ -747,6 +763,12 @@ class SimpleConv:
continue continue
if SPCONV_DEBUG_NVRTC_KERNELS: if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True desp.is_nvrtc = True
if require_dynamic_mask:
if not desp.dynamic_mask:
continue
else:
if desp.dynamic_mask:
continue
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -758,11 +780,12 @@ class SimpleConv: ...@@ -758,11 +780,12 @@ class SimpleConv:
k: int, k: int,
c: int, c: int,
arch: Tuple[int, int], arch: Tuple[int, int],
mask_width: int = -1): mask_width: int = -1,
need_dynamic_mask: bool = False):
if not op_type == ConvOpType.kBackwardWeight: if not op_type == ConvOpType.kBackwardWeight:
# fwd and dgrad don't need # fwd and dgrad don't need
mask_width = -1 mask_width = -1
key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width) key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width, need_dynamic_mask)
if op_type == ConvOpType.kForward: if op_type == ConvOpType.kForward:
return self.kc_forward_cache.get(key, None) return self.kc_forward_cache.get(key, None)
elif op_type == ConvOpType.kBackwardInput: elif op_type == ConvOpType.kBackwardInput:
...@@ -795,8 +818,9 @@ class SimpleConv: ...@@ -795,8 +818,9 @@ class SimpleConv:
cudadevrt = str(cudadevrt_p) cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel], mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt, cudadevrt_path=cudadevrt,
verbose=False, verbose=True,
custom_names=custom_names) custom_names=custom_names,
verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod.load() mod.load()
return mod, kernel return mod, kernel
...@@ -824,7 +848,6 @@ class SimpleConv: ...@@ -824,7 +848,6 @@ class SimpleConv:
mask_argsort: tv.Tensor, mask_argsort: tv.Tensor,
indices: tv.Tensor, indices: tv.Tensor,
reverse_mask: bool, reverse_mask: bool,
mask_int_count: int = 1,
mask_filter: int = 0xffffffff, mask_filter: int = 0xffffffff,
mask_width: int = -1, mask_width: int = -1,
mask_output: tv.Tensor = tv.Tensor(), mask_output: tv.Tensor = tv.Tensor(),
...@@ -832,17 +855,20 @@ class SimpleConv: ...@@ -832,17 +855,20 @@ class SimpleConv:
beta: float = 0.0, beta: float = 0.0,
stream: int = 0, stream: int = 0,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
use_tf32: bool = True): use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width, layout_o, arch, op_type, mask_width,
fp32_accum, use_tf32) fp32_accum, use_tf32, bias, scale)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
output = output.clone() output = output.clone()
print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty())
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
weight = weight.view([channel_k, -1, channel_c])
need_dynamic_mask = weight.dim(1) > 32
times: List[float] = [] times: List[float] = []
all_profile_res: List[BestConvAlgoByProfile] = [] all_profile_res: List[BestConvAlgoByProfile] = []
group_by_algo = {} group_by_algo = {}
...@@ -865,8 +891,9 @@ class SimpleConv: ...@@ -865,8 +891,9 @@ class SimpleConv:
params.indices = indices params.indices = indices
params.mask = mask params.mask = mask
params.mask_output = mask_output params.mask_output = mask_output
params.mask_int_count = mask_int_count if desp.is_int8_inference:
params.bias = bias
params.scale = scale
# if op_type == ConvOpType.kBackwardWeight: # if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty() # assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput: if op_type == ConvOpType.kBackwardInput:
...@@ -909,7 +936,7 @@ class SimpleConv: ...@@ -909,7 +936,7 @@ class SimpleConv:
# fwd and dgrad don't need # fwd and dgrad don't need
mask_width = -1 mask_width = -1
key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c, key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
arch[0], arch[1], mask_width) arch[0], arch[1], mask_width, need_dynamic_mask)
with self.lock: with self.lock:
if op_type == ConvOpType.kForward: if op_type == ConvOpType.kForward:
self.kc_forward_cache[key] = res self.kc_forward_cache[key] = res
...@@ -945,7 +972,9 @@ class SimpleConv: ...@@ -945,7 +972,9 @@ class SimpleConv:
act_alpha: float = 0.0, act_alpha: float = 0.0,
act_beta: float = 0.0, act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_, act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
mask_int_count: Union[int, None] = None): scale: Optional[tv.Tensor] = None,
output_add: Optional[tv.Tensor] = None):
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream) # GemmMainUnitTest.stream_synchronize(stream)
...@@ -986,9 +1015,12 @@ class SimpleConv: ...@@ -986,9 +1015,12 @@ class SimpleConv:
params.mask_filter = mask_filter params.mask_filter = mask_filter
params.mask_output = mask_output params.mask_output = mask_output
params.reverse_mask = reverse_mask params.reverse_mask = reverse_mask
params.mask_int_count = mask_int_count
if bias is not None: if bias is not None:
params.bias = bias params.bias = bias
if output_add is not None and algo_desp.is_int8_inference:
params.output_add = output_add
if scale is not None and algo_desp.is_int8_inference:
params.scale = scale
if timer.enable: if timer.enable:
assert timer._timer is not None assert timer._timer is not None
params.timer = timer._timer params.timer = timer._timer
......
...@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp ...@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp], def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]): p: Union[GemmAlgoParams, ConvAlgoParams]):
desp.dtype_a = p.dtype_a.tv_dtype desp.dtype_a = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_a.tv_dtype desp.dtype_b = p.dtype_b.tv_dtype
desp.dtype_c = p.dtype_a.tv_dtype desp.dtype_c = p.dtype_c.tv_dtype
desp.dacc = p.dtype_acc.tv_dtype desp.dacc = p.dtype_acc.tv_dtype
desp.dcomp = p.dtype_comp.tv_dtype desp.dcomp = p.dtype_comp.tv_dtype
desp.trans_a = p.trans_a desp.trans_a = p.trans_a
...@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams): ...@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc
desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc
desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc
desp.is_int8_inference = ker.int8_inference
desp.dynamic_mask = ker.dynamic_mask
desp.min_arch = ker.min_arch() desp.min_arch = ker.min_arch()
return desp return desp
...@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp): ...@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
desp.interleave_o) desp.interleave_o)
p.mask_sparse = desp.mask_sparse p.mask_sparse = desp.mask_sparse
p.increment_k_first = desp.increment_k_first p.increment_k_first = desp.increment_k_first
p.int8_inference = desp.is_int8_inference
p.dynamic_mask = desp.dynamic_mask
return p return p
...@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.inference import InferenceOps from spconv.csrc.sparse.inference import InferenceOps
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) # all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) # all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
gemmtuner = GemmTunerSimple(cu) gemmtuner = GemmTunerSimple(cu)
......
...@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3], [2, 3],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
] ]
IMPLGEMM_TURING_PARAMS = [ IMPLGEMM_TURING_PARAMS = [
...@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
......
...@@ -144,7 +144,7 @@ class SpconvOps: ...@@ -144,7 +144,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -167,11 +167,10 @@ class SpconvOps: ...@@ -167,11 +167,10 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -194,11 +193,10 @@ class SpconvOps: ...@@ -194,11 +193,10 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -214,7 +212,6 @@ class SpconvOps: ...@@ -214,7 +212,6 @@ class SpconvOps:
indice_pair_mask: indice_pair_mask:
backward: backward:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
...@@ -383,65 +380,25 @@ class SpconvOps: ...@@ -383,65 +380,25 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_mask32(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
""" """
Args: Args:
data: data:
alloc_func: alloc_func:
indices: indices:
stream: stream:
mask_count:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_mask32_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
""" """
Args: Args:
data: data:
allocator: allocator:
indices: indices:
stream: stream:
""" mask_count:
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
...@@ -598,7 +555,7 @@ class SpconvOps: ...@@ -598,7 +555,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
""" """
Args: Args:
allocator: allocator:
......
...@@ -20,7 +20,7 @@ class ConvTunerSimple: ...@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch: arch:
""" """
... ...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]: def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> List[ConvAlgoDesp]:
""" """
Args: Args:
inp: inp:
...@@ -38,6 +38,8 @@ class ConvTunerSimple: ...@@ -38,6 +38,8 @@ class ConvTunerSimple:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
use_tf32: use_tf32:
bias:
scale:
""" """
... ...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams: def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...@@ -48,7 +50,7 @@ class ConvTunerSimple: ...@@ -48,7 +50,7 @@ class ConvTunerSimple:
stream_int: stream_int:
""" """
... ...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]: def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> Tuple[ConvTuneResult, float]:
""" """
Args: Args:
op_type: op_type:
...@@ -72,14 +74,15 @@ class ConvTunerSimple: ...@@ -72,14 +74,15 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
num_run: num_run:
use_tf32: use_tf32:
bias:
scale:
""" """
... ...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]: def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1, need_dynamic_mask: bool = False) -> Tuple[Any, bool]:
""" """
Args: Args:
op_type: op_type:
...@@ -90,9 +93,10 @@ class ConvTunerSimple: ...@@ -90,9 +93,10 @@ class ConvTunerSimple:
c: c:
arch: arch:
mask_width: mask_width:
need_dynamic_mask:
""" """
... ...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None: def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, scale: Tensor = Tensor(), output_add: Tensor = Tensor()) -> None:
""" """
Args: Args:
profile_res: profile_res:
...@@ -110,7 +114,6 @@ class ConvTunerSimple: ...@@ -110,7 +114,6 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
workspace: workspace:
verbose: verbose:
timer: timer:
...@@ -119,6 +122,8 @@ class ConvTunerSimple: ...@@ -119,6 +122,8 @@ class ConvTunerSimple:
act_alpha: act_alpha:
act_beta: act_beta:
act_type: act_type:
scale:
output_add:
""" """
... ...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int: def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
......
...@@ -63,7 +63,7 @@ class ConvGemmOps: ...@@ -63,7 +63,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
...@@ -75,7 +75,6 @@ class ConvGemmOps: ...@@ -75,7 +75,6 @@ class ConvGemmOps:
mask_argsort_fwd_splits: mask_argsort_fwd_splits:
num_activate_out: num_activate_out:
masks: masks:
mask_int_count:
arch: arch:
is_train: is_train:
is_subm: is_subm:
...@@ -88,10 +87,14 @@ class ConvGemmOps: ...@@ -88,10 +87,14 @@ class ConvGemmOps:
act_beta: act_beta:
act_type: act_type:
use_tf32: use_tf32:
output_scale:
scale:
output_add:
output_add_scale:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None: def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -107,7 +110,6 @@ class ConvGemmOps: ...@@ -107,7 +110,6 @@ class ConvGemmOps:
mask_argsort_bwd_splits: mask_argsort_bwd_splits:
mask_output_fwd: mask_output_fwd:
masks: masks:
mask_int_count:
arch: arch:
mask_width: mask_width:
is_subm: is_subm:
......
...@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator ...@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys
import re import re
import os import os
from cumm.gemm.codeops import dispatch
class CustomThrustLib(pccm.Class): class CustomThrustLib(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class): ...@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class): ...@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class): ...@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class): ...@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class): ...@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false") code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim && TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class): ...@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc, indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_, batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward, ksize_, dilation_, indice_pair_mask, backward,
stream_int, mask_int_count); stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class): ...@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
""") """)
return code return code
def sort_1d_by_key_allocator_template(self, use_allocator: bool, int_count: int = 1): def sort_1d_by_key_allocator_template(self, use_allocator: bool):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
...@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class): ...@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
"tv::Tensor()", "tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()") pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f""" code.arg("mask_count", "int", "1", pyanno="int")
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel) code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator: if not use_allocator:
...@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class): ...@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{ if (indices.empty()){{
indices = tv::empty({{data.dim(0) / {int_count}}}, tv::int32, 0); indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}} }}
tv::cuda::Launch launcher(data.dim(0), stream_cu); tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0)); launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer(); // auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{ """)
using T_ = TV_DECLTYPE(I); # nested tv::dispatch may cause compiler bug in msvc.
using T = {"T_" if int_count == 1 else f"thrust::tuple<{', '.join(['T_'] * int_count) }>"}; for dtype in dispatch(code, [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64], "data.dtype()"):
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>())); code.raw(f"""
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>()); using T_ = {dtype};
auto thrust_ctx = thrust::cuda::par.on(stream_cu); tv::dispatch_int<1, 2, 3, 4>(mask_count, [&](auto IV){{
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu); constexpr int I = TV_DECLTYPE(IV)::value;
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) / {int_count}, ptr_k); // we can't use thrust::tuple in mp_repeat_c directly because
}}); // thrust tuple actually has fixed size template arguments.
using T = tv::mp_rename<tv::mp_repeat_c<tv::mp_list<T_>, I>, thrust::tuple>;
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}});
""")
code.raw(f"""
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0); // tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices; return indices;
""") """)
...@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class): ...@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32(self): def sort_1d_by_key_allocator(self):
# for python
return self.sort_1d_by_key_allocator_template(False) return self.sort_1d_by_key_allocator_template(False)
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128(self):
# for python
return self.sort_1d_by_key_allocator_template(False, 4)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True, 4)
def sort_1d_by_key_allocator_mask_auto_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
if not use_allocator:
code.arg("alloc_param", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("alloc_param", "ThrustAllocator&")
code.arg("indices",
"tv::Tensor",
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto(self):
return self.sort_1d_by_key_allocator_mask_auto_template(False)
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto_v2(self):
return self.sort_1d_by_key_allocator_mask_auto_template(True)
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator_v2(self): def sort_1d_by_key_allocator_v2(self):
# for cpp only
return self.sort_1d_by_key_allocator_template(True) return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark @pccm.pybind.mark
...@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class): ...@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
int hash_size = 2 * num_act_out_bound; int hash_size = 2 * num_act_out_bound;
if (direct_table){{ if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory); hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}} }}
size_t res = 0; size_t res = 0;
if (subm){{ if (subm){{
...@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class): ...@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
max_act_out_in_theory, subm, use_int64_hash_k, direct_table); max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound; int hash_size = 2 * num_act_out_bound;
if (direct_table){{ if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory); hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}} }}
if (use_int64_hash_k){{ if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0); auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
...@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class): ...@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo); auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>()); int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int mask_int_count = (kv + 31) / 32; int mask_int_count = tv::div_up(kv, 32);
if (mask_int_count > 1 && mask_int_count < 4) // if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4; // mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel"); // TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32"); // TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape; std::vector<int> out_shape;
if (!subm){{ if (!subm){{
...@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class): ...@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)}); pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{ }}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)}, pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_in, mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc, generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int, mask_int_count); batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count); sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}} }}
""") """)
with code.else_(): with code.else_():
...@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n"; ...@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)}, pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int); {{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)}, pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_out, mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor(); pair_mask_bwd = tv::Tensor();
if (is_train){{ if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)}, pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0) * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, indices.dim(0), mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
}} }}
if (!direct_table){{ if (!direct_table){{
...@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n"; ...@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count); transposed, stream_int);
}}else{{ }}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd, generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count); transposed, stream_int);
}} }}
}} }}
""") """)
...@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n"; ...@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
}} }}
}}else{{ }}else{{
if (!is_train){{ if (!is_train){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{ }}else{{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_bwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count); mask_argsort_bwd[0], stream_int, mask_int_count);
}} }}
}} }}
}} }}
""") """)
code.raw(f""" code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out, mask_int_count); return std::make_tuple(mask_tensor, num_act_out);
""") """)
return code.ret("std::tuple<tv::Tensor, int, int>") return code.ret("std::tuple<tv::Tensor, int>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
......
...@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
"int, int, int, int, int>")) "int, int, int, int, int>"))
self.add_typedef( self.add_typedef(
"algo_cache_key_t", "std::tuple<int, int, int, int, " "algo_cache_key_t", "std::tuple<int, int, int, int, "
"int, int, int, int>") "int, int, int, int, bool>")
self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>") self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>")
self.add_member( self.add_member(
...@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("auto_fp32_accum", "bool") code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool") code.arg("fp32_accum", "bool")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.raw(f""" code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type); tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
...@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(mask_width > 0, "eroro"); TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0; mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}} }}
bool require_dynamic_mask = kv > 32;
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{ if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!bias.empty() && !scale.empty()){{
TV_ASSERT_RT_ERR(bias.dtype() == scale.dtype(), "bias/scale dtype must equal to compute dtype in gemm");
if (desp.dcomp != bias.dtype()){{
continue;
}}
if (!desp.is_int8_inference){{
continue;
}}
}}else{{
if (desp.is_int8_inference){{
continue;
}}
}}
auto desp2 = desp; auto desp2 = desp;
if (desp.is_nvrtc){{ if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{ if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
...@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
}} }}
}} }}
if (require_dynamic_mask){{
if (!desp.dynamic_mask){{
continue;
}}
}}else{{
if (desp.dynamic_mask){{
continue;
}}
}}
finally_algos.push_back(desp2); finally_algos.push_back(desp2);
}} }}
}} }}
...@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.arg("auto_fp32_accum", "bool", "true") code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false") code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5") code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
...@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w, auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o, layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width, arch, op_type, mask_width,
auto_fp32_accum, fp32_accum, use_tf32); auto_fp32_accum, fp32_accum, use_tf32,
bias, scale);
inp = inp.clone(); inp = inp.clone();
weight = weight.clone(); weight = weight.clone();
bool need_dynamic_mask = weight.dim(1) > 32;
output = output.clone(); output = output.clone();
int channel_k = output.dim(1); int channel_k = output.dim(1);
int channel_c = inp.dim(1); int channel_c = inp.dim(1);
weight = weight.view(channel_k, -1, channel_c);
std::vector<ConvTuneResult> all_profile_res; std::vector<ConvTuneResult> all_profile_res;
std::unordered_set<int> splitk_tests; std::unordered_set<int> splitk_tests;
...@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices; params.indices = indices;
params.mask = mask; params.mask = mask;
params.mask_output = mask_output; params.mask_output = mask_output;
params.mask_int_count = mask_int_count; if (desp.is_int8_inference){{
params.bias = bias;
params.scale = scale;
}}
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{ // if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error"); // TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
...@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
algo_cache_key_t key; algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()), key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width); int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
{{ {{
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
...@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("k, c", "int") code.arg("k, c", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int", "-1") code.arg("mask_width", "int", "-1")
code.arg("need_dynamic_mask", "bool", "false")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("std::tuple<ConvTuneResult, bool>") return code.ret("std::tuple<ConvTuneResult, bool>")
...@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
algo_cache_key_t key; algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c, key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width); std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
ConvTuneResult res; ConvTuneResult res;
bool exists = false; bool exists = false;
{{ {{
...@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.arg("workspace", "tv::Tensor", "tv::Tensor()", code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false") code.arg("verbose", f"bool", "false")
...@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0") code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code return code
...@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.output = output; params.output = output;
params.verbose = verbose; params.verbose = verbose;
params.bias = bias; params.bias = bias;
params.scale = scale;
params.split_k_slices = split_k_slices; params.split_k_slices = split_k_slices;
params.alpha = alpha; params.alpha = alpha;
...@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.act_alpha = act_alpha; params.act_alpha = act_alpha;
params.act_beta = act_beta; params.act_beta = act_beta;
params.act_type = act_type; params.act_type = act_type;
if (!output_add.empty() && desp.is_int8_inference){{
params.output_add = output_add;
}}
params.stream = stream_int; params.stream = stream_int;
params.mask_argsort = mask_argsort; params.mask_argsort = mask_argsort;
params.indices = indices; params.indices = indices;
...@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width; params.mask_width = mask_width;
params.mask_output = mask_output; params.mask_output = mask_output;
params.reverse_mask = reverse_mask; params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{ if (timer.enable()){{
params.timer = timer; params.timer = timer;
...@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>") "std::vector<tv::Tensor>")
code.arg("num_activate_out", "int") code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false") code.arg("is_train, is_subm", "bool", "false")
...@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("output_scale", "float", "1.0")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add_scale", "float", "1.0")
code.arg("output_dtype", "int", "-1")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
...@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
int num_split = pair_mask_fwd_splits.size(); int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error"); TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel); filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
int mask_int_count = tv::div_up(kv, 32);
tv::Tensor out_features; tv::Tensor out_features;
if (output_dtype < 0){{
output_dtype = int(features.dtype());
}}
if (is_subm){{ if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}}else{{ }}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}} }}
// auto start_ev = tv::CUDAEvent(); // auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int); // start_ev.record(stream_int);
...@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
use_tf32); use_tf32,
bias,
scale);
tune_res = std::get<0>(tune_res_time); tune_res = std::get<0>(tune_res_time);
}} }}
float alpha = 1.0;
if (tune_res.algo_desp.is_int8_inference){{
alpha = output_scale;
}}
int mask_width = tune_res.algo_desp.tile_shape[0]; int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd; tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits; std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{ if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)}, mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width) * mask_int_count}}, {{num_split, tv::div_up(num_activate_out, mask_width), mask_int_count}},
tv::uint32, features.device(), stream_int); tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{ for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]); mask_output_fwd_splits.push_back(mask_output_fwd[i]);
...@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
for (int j = 0; j < num_split; ++j){{ for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1; float beta = j == 0 ? 0 : 1;
if (!bias.empty()){{ if (!bias.empty() && !tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = 1; beta = 1;
}} }}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
}}
if (j > 0){{ if (j > 0){{
bias = tv::Tensor(); bias = tv::Tensor();
}} }}
...@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // reverse_mask false, // reverse_mask
mask_ptr[j], mask_ptr[j],
-1, // mask_width -1, // mask_width
1.0, beta, alpha, beta,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer, timer,
...@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
bias, bias,
act_alpha, act_alpha,
act_beta, act_beta,
act_type); act_type,
scale,
output_add);
}} }}
// auto end_ev = tv::CUDAEvent(); // auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int); // end_ev.record(stream_int);
...@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor") code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int") code.arg("mask_width", "int")
...@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width -1, // mask_width
1.0, beta, 1.0, beta,
stream_int, stream_int,
mask_int_count,
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer); timer);
...@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width, mask_width,
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
workspace, // workspace workspace, // workspace
false, // verbose false, // verbose
timer); timer);
......
...@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32)); uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32; uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2)); // uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset); loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS; int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
...@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
// indice_pairs_bwd: [kv, num_act_in] or empty // indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, num_act_out] // indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context(); auto ctx = tv::Context();
...@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("is_train", "bool", "true") code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int num_act_in_real = indices.dim(0); int num_act_in_real = indices.dim(0);
...@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
padding[i] = (ksize[i] / 2) * dilation[i]; padding[i] = (ksize[i] / 2) * dilation[i];
}} }}
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in] // indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
...@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error"); TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error"); TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error"); TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 3, "error");
// indice_pair_mask: [mask_split_count, num_act_in] // indice_pair_mask: [mask_split_count, num_act_in, num_mask_per_point]
if (indice_pair_mask.dim(0) == 2){{ if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real); auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real); auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
...@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv, is_train); indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{ }}else{{
// indice_pair_mask: [1, num_act_in] // indice_pair_mask: [1, num_act_in, num_mask_per_point]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream); tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1) if (mask_int_count == 1){{
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0)); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else }}
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>, else{{
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count); lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
}}
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
......
...@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(), table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(), count.data_ptr<int>(),
layout, voxels.dim(0)); layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(), launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(), point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1), num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv, voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0)); grid_size_tv, grid_stride_tv, points.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
auto voxel_launcher = tv::cuda::Launch(count_val, custream); auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{ if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(), launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(),
......
...@@ -37,10 +37,23 @@ from spconv.utils import nullcontext ...@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
from torch.nn.init import calculate_gain from torch.nn.init import calculate_gain
from cumm import tensorview as tv from cumm import tensorview as tv
from torch.nn import functional as F
FILTER_HWIO = False FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training" _MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, act_beta: float):
if act_type == tv.gemm.Activation.None_:
return x
elif act_type == tv.gemm.Activation.ReLU:
return F.relu(x)
elif act_type == tv.gemm.Activation.Sigmoid:
return F.sigmoid(x)
elif act_type == tv.gemm.Activation.LeakyReLU:
return F.leaky_relu(x, act_alpha)
else:
raise NotImplementedError
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
...@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule): ...@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
torch.zeros(1, dtype=torch.int32)) torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count self.record_voxel_count = record_voxel_count
if algo is None: if algo is None:
if kv <= 32 and not CPU_ONLY_BUILD: if kv <= 128 and not CPU_ONLY_BUILD:
if kv < 8: if kv < 8:
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
...@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule): ...@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
self.act_type = act_type self.act_type = act_type
self.act_alpha = act_alpha self.act_alpha = act_alpha
self.act_beta = act_beta self.act_beta = act_beta
self.enable_int8_test_mode: bool = False
self._int8_weight = torch.Tensor()
# calculated by max(abs(weight)) for each channel
self._int8_weight_scale = torch.Tensor()
# calculated by scale self.bias with _int8_input_scale
self._int8_bias = torch.Tensor()
# int8 inference must set _int8_input_scale
self._int8_input_scale: Optional[float] = None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self._int8_output_scale: Optional[float] = None
if self.conv1x1: if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act" assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
self.reset_parameters() self.reset_parameters()
...@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule): ...@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING) return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None return None
def set_int8_test(self, enable: bool, input_scale: float, output_scale: Optional[float] = None, weight_scale: Optional[torch.Tensor] = None):
self._int8_input_scale = input_scale
self._int8_output_scale = output_scale
if weight_scale is not None:
self._int8_weight_scale = weight_scale
self.enable_int8_test_mode = enable
def _load_weight_different_layout(self, state_dict, prefix, local_metadata, def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, strict, missing_keys, unexpected_keys,
error_msgs): error_msgs):
if self.record_voxel_count and not self.subm and not self.inverse and _MAX_NUM_VOXELS_DURING_TRAINING not in state_dict: name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
state_dict[prefix + _MAX_NUM_VOXELS_DURING_TRAINING] = torch.zeros( if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32) 1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT: if not SAVED_WEIGHT_LAYOUT:
return return
...@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule): ...@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
def is_inverseable(self): def is_inverseable(self):
return self.indice_key is not None and not self.subm return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor): def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(input, self.weight, self.bias, add_input)
def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None):
assert isinstance(input, SparseConvTensor) assert isinstance(input, SparseConvTensor)
assert input.features.shape[ assert input.features.shape[
1] == self.in_channels, "channel size mismatch" 1] == self.in_channels, "channel size mismatch"
...@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule): ...@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
indices = input.indices indices = input.indices
spatial_shape = input.spatial_shape spatial_shape = input.spatial_shape
batch_size = input.batch_size batch_size = input.batch_size
bias_for_training = self.bias if self.training else None bias_for_training = bias if self.training else None
bias_for_infer = self.bias if not self.training else None bias_for_infer = bias if not self.training else None
output_scale = None
output_add_scale = 1.0
if self.enable_int8_test_mode:
assert not self.training, "must in eval mode"
assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
assert bias_for_infer is not None, "conv-bn-relu must be fused"
assert self._int8_input_scale is not None
if features.dtype != torch.int8:
# quantize
features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
output_scale = self._int8_output_scale
int8_out_scale = output_scale
if int8_out_scale is None:
int8_out_scale = 1
if add_input is not None:
assert add_input.int8_scale is not None, "only support int8 add"
output_add_scale = add_input.int8_scale
if self._int8_weight.numel() == 0:
with torch.no_grad():
assert ALL_WEIGHT_IS_KRSC
weight_scales = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
num_1s = [1] * (self.ndim + 1)
self._int8_weight = (weight / weight_scales.view(self.out_channels, *num_1s) * 127).to(torch.int8)
if self._int8_weight_scale.numel() == 0:
self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scales)
self._int8_bias = bias_for_infer * int8_out_scale
if self.training: if self.training:
msg = "act don't support backward, only used in inference" msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg assert self.act_type == tv.gemm.Activation.None_, msg
...@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule): ...@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
"out_channels": self.out_channels, "out_channels": self.out_channels,
} }
} }
if self.conv1x1: if self.conv1x1 and not self.enable_int8_test_mode:
# in int8 test mode, we don't implement conv1x1 via mm.
if FILTER_HWIO: if FILTER_HWIO:
features = torch.mm( features = torch.mm(
input.features, input.features,
self.weight.view(self.out_channels, self.in_channels).T) weight.view(self.out_channels, self.in_channels).T)
else: else:
features = torch.mm( features = torch.mm(
input.features, input.features,
self.weight.view(self.in_channels, self.out_channels)) weight.view(self.in_channels, self.out_channels))
if self.bias is not None: if bias is not None:
features += self.bias features += bias
out_tensor = out_tensor.replace_feature(features) out_tensor = out_tensor.replace_feature(features)
# padding may change spatial shape of conv 1x1. # padding may change spatial shape of conv 1x1.
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
...@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule): ...@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv( out_features = Fsp.indice_subm_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule): ...@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule): ...@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
else: else:
out_features = Fsp.indice_conv( out_features = Fsp.indice_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule): ...@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks masks = datas.masks
mask_int_count = datas.mask_int_count
assert self.subm, "only support reuse subm indices" assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, self._check_subm_reuse_valid(input, spatial_shape,
datas) datas)
else: else:
if input.benchmark:
torch.cuda.synchronize()
t = time.time()
with input._timer.namespace("gen_pairs"): with input._timer.namespace("gen_pairs"):
# we need to gen bwd indices for regular conv # we need to gen bwd indices for regular conv
# because it may be inversed. # because it may be inversed.
...@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule): ...@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data(indices) spconv_save_debug_data(indices)
raise e raise e
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
pair_fwd = res[2] pair_fwd = res[2]
...@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule): ...@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = res[6] mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7] mask_argsort_bwd_splits = res[7]
masks = res[8] masks = res[8]
mask_int_count = res[9]
if self.indice_key is not None: if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData( indice_data = ImplicitGemmIndiceData(
outids, outids,
...@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule): ...@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
ksize=self.kernel_size, ksize=self.kernel_size,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
dilation=self.dilation, dilation=self.dilation)
mask_int_count=mask_int_count)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor." msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data indice_dict[self.indice_key] = indice_data
...@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule): ...@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
num_activate_out = outids.shape[0] num_activate_out = outids.shape[0]
out_features = Fsp.implicit_gemm( weight_cur = weight
features, self.weight, pair_fwd, pair_bwd, bias_cur = bias_for_infer
pair_mask_fwd_splits, pair_mask_bwd_splits, if self.enable_int8_test_mode:
mask_argsort_fwd_splits, mask_argsort_bwd_splits, assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
num_activate_out, masks, mask_int_count, self.training, self.subm, weight_cur = self._int8_weight
input._timer, self.fp32_accum, bias_cur = self._int8_bias
bias_for_infer, if self.training:
self.act_alpha, out_features = Fsp.implicit_gemm(
self.act_beta, features, weight_cur, pair_fwd, pair_bwd,
self.act_type) pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type)
else:
output_dtype = None
if self._int8_output_scale is None:
output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type,
# TODO do we really need output scale to scale bias in kernel?
1.0, # output_scale
self._int8_weight_scale, # scale
output_add=add_input.features if add_input is not None else None,
output_add_scale=output_add_scale,
output_dtype=output_dtype)
if bias_for_training is not None: if bias_for_training is not None:
out_features += bias_for_training out_features += bias_for_training
if input.benchmark: if input.benchmark:
...@@ -581,9 +675,10 @@ class SparseConvolution(SparseModule): ...@@ -581,9 +675,10 @@ class SparseConvolution(SparseModule):
out_tensor.indices = outids out_tensor.indices = outids
out_tensor.indice_dict = indice_dict out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
# print(outids.shape, spatial_shape, self.kernel_size, self.stride, self.padding, if add_input is not None and not self.enable_int8_test_mode:
# self.dilation, self.output_padding, out_spatial_shape) out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
out_tensor.int8_scale = output_scale
return out_tensor return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor, def _check_subm_reuse_valid(self, inp: SparseConvTensor,
......
...@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object): ...@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape, is_subm: bool, algo: ConvAlgo, out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int], ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None, in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None, out_voxel_num: Optional[Any] = None):
mask_int_count: int=1):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object): ...@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion. # in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num self.out_voxel_num = out_voxel_num
self.mask_int_count = mask_int_count
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
...@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer) self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo self.force_algo = force_algo
# for simple int8 torch inference
self.int8_scale: Optional[float] = None
def replace_feature(self, feature: torch.Tensor): def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
......
...@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int, num_activate_out: int,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
try: try:
out, mask_out, mask_width = ops.implicit_gemm( out, mask_out, mask_width = ops.implicit_gemm(
features, filters, pair_fwd, pair_mask_fwd_splits, features, filters, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits, num_activate_out, masks, mask_int_count, is_train, mask_argsort_fwd_splits, num_activate_out, masks, is_train,
is_subm, timer, fp32_accum, bias, act_alpha, act_beta, is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
act_type) act_type)
except Exception as e: except Exception as e:
...@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
ctx.masks = masks ctx.masks = masks
ctx.is_subm = is_subm ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum ctx.fp32_accum = fp32_accum
ctx.mask_int_count = mask_int_count
return out return out
@staticmethod @staticmethod
...@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer timer = ctx.timer
fp32_accum = ctx.fp32_accum fp32_accum = ctx.fp32_accum
mask_int_count = ctx.mask_int_count
try: try:
input_bp, filters_bp = ops.implicit_gemm_backward( input_bp, filters_bp = ops.implicit_gemm_backward(
...@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, mask_argsort_bwd_splits,
mask_output_fwd=mask_out, mask_output_fwd=mask_out,
masks=masks, masks=masks,
mask_int_count=mask_int_count,
mask_width=mask_width, mask_width=mask_width,
is_subm=is_subm, is_subm=is_subm,
timer=timer, timer=timer,
...@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, masks)) mask_argsort_bwd_splits, masks))
raise e raise e
None_9 = [None] * 17 None_9 = [None] * 16
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment