Commit aa26c99e authored by yan.yan's avatar yan.yan
Browse files

working on quantization

parent ee8c9465
...@@ -207,7 +207,7 @@ class Net(nn.Module): ...@@ -207,7 +207,7 @@ class Net(nn.Module):
pool_algo = algo pool_algo = algo
# pool_algo = ConvAlgo.Native # pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0", spconv.SubMConv3d(16, 64, 3, bias=False, indice_key="c0",
algo=algo), algo=algo),
nn.BatchNorm1d(64), nn.BatchNorm1d(64),
nn.ReLU(), nn.ReLU(),
...@@ -373,6 +373,11 @@ class Net(nn.Module): ...@@ -373,6 +373,11 @@ class Net(nn.Module):
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, voxel_num=vx_num) x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, voxel_num=vx_num)
return self.net(x) return self.net(x)
def _set_enable_int8_test_inplace(simple_module: torch.fx.GraphModule, enable: bool):
for m in simple_module.modules():
if isinstance(m, SparseConvolution):
if m.in_channels % 32 == 0 and m.out_channels % 32 == 0:
m.enable_int8_test_mode = enable
class MyTracer(torch.fx.Tracer): class MyTracer(torch.fx.Tracer):
...@@ -387,6 +392,7 @@ def main(): ...@@ -387,6 +392,7 @@ def main():
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False
with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f: with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f) (voxels, coors, spatial_shape) = pickle.load(f)
voxels = np.random.uniform(-1, 1, size=[voxels.shape[0], 16]).astype(np.float32)
np.random.seed(50051) np.random.seed(50051)
device = torch.device("cuda:0") device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0") device_cpu = torch.device("cpu:0")
...@@ -408,6 +414,10 @@ def main(): ...@@ -408,6 +414,10 @@ def main():
out_fused = net_fused(voxels_th_cuda, coors_th_cuda, 1) out_fused = net_fused(voxels_th_cuda, coors_th_cuda, 1)
res = Fsp.sparse_add_hash_based(out_ref, out_fused.minus()) res = Fsp.sparse_add_hash_based(out_ref, out_fused.minus())
print(torch.linalg.norm(res.features)) print(torch.linalg.norm(res.features))
_set_enable_int8_test_inplace(net_fused, True)
qvoxels_cuda = voxels_th_cuda.to(torch.int8)
out_int8 = net_fused(qvoxels_cuda, coors_th_cuda, 1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -426,7 +426,7 @@ int main(int argc, char **argv) { ...@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
{SPCONV_ALLOC_OUT_FEATURES, out_features}}; {SPCONV_ALLOC_OUT_FEATURES, out_features}};
StaticAllocator alloc2(tensor_dict); StaticAllocator alloc2(tensor_dict);
ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp()); ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp());
auto conv_res = ConvGemmOps::implicit_gemm( auto conv_run_status = ConvGemmOps::implicit_gemm(
alloc2, tuner, input_features_real, weights, pair_fwd_real, alloc2, tuner, input_features_real, weights, pair_fwd_real,
pair_mask_splits, mask_argsort_splits, num_act_out_real, pair_mask_splits, mask_argsort_splits, num_act_out_real,
mask_tensor, arch, false, is_subm, mask_tensor, arch, false, is_subm,
...@@ -435,7 +435,7 @@ int main(int argc, char **argv) { ...@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
1.0 /*bias alpha, only used for leaky relu*/, 1.0 /*bias alpha, only used for leaky relu*/,
0.0 /*unused for now*/, tv::gemm::Activation::kReLU); 0.0 /*unused for now*/, tv::gemm::Activation::kReLU);
tv::ssprint("selected conv algo", tv::ssprint("selected conv algo",
std::get<1>(conv_res).algo_desp.__repr__()); std::get<1>(conv_run_status).algo_desp.__repr__());
// FINISH!!! // FINISH!!!
} }
// calc maximum number of output points. // calc maximum number of output points.
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import torch
import spconv.pytorch as spconv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
@contextlib.contextmanager
def identity_ctx():
yield
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = spconv.SparseSequential(
nn.BatchNorm1d(1),
spconv.SubMConv2d(1, 32, 3, 1),
nn.ReLU(),
spconv.SubMConv2d(32, 64, 3, 1),
nn.ReLU(),
spconv.SparseConv2d(64, 64, 2, 2),
spconv.ToDense(),
)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
def forward(self, x: torch.Tensor):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp)
x = torch.flatten(x, 1)
x = self.dropout1(x)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
scaler = torch.cuda.amp.grad_scaler.GradScaler()
amp_ctx = contextlib.nullcontext()
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with amp_ctx:
output = model(data)
loss = F.nll_loss(output, target)
scale = 1.0
if args.fp16:
assert loss.dtype is torch.float32
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
# scaler.unscale_(optim)
# Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
# You may use the same value for max_norm here as you would without gradient scaling.
# torch.nn.utils.clip_grad_norm_(models[0].net.parameters(), max_norm=0.1)
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
scale = scaler.get_scale()
else:
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
amp_ctx = contextlib.nullcontext()
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with amp_ctx:
output = model(data)
test_loss += F.nll_loss(
output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(
dim=1,
keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs',
type=int,
default=14,
metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr',
type=float,
default=1.0,
metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma',
type=float,
default=0.7,
metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument('--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model',
action='store_true',
default=False,
help='For Saving the current Model')
parser.add_argument('--fp16',
action='store_true',
default=False,
help='For mixed precision training')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
if __name__ == '__main__':
main()
[build-system] [build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"] requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"] # requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
...@@ -616,6 +616,7 @@ class SimpleConv: ...@@ -616,6 +616,7 @@ class SimpleConv:
algocore.get_conv_algo_desp_from_param(p) algocore.get_conv_algo_desp_from_param(p)
for p in ALL_IMPGEMM_PARAMS for p in ALL_IMPGEMM_PARAMS
] ]
self.all_desps = all_desps
self.prebuilt_desps = prebuilt_desps self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps} self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
...@@ -648,13 +649,13 @@ class SimpleConv: ...@@ -648,13 +649,13 @@ class SimpleConv:
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos) tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
self.kc_forward_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_forward_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], int, bool],
BestConvAlgoByProfile] = {} # for forward BestConvAlgoByProfile] = {} # for forward
self.kc_dgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_dgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int, bool], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int, bool], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {} self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
...@@ -679,11 +680,12 @@ class SimpleConv: ...@@ -679,11 +680,12 @@ class SimpleConv:
op_type: ConvOpType, op_type: ConvOpType,
mask_width: int, mask_width: int,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
use_tf32: bool = True): use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16 is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 # and out.dtype == tv.float16
use_f32_as_accum = False use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1])) kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during # for 3d conv, if reduce axis is too large, may cause nan during
...@@ -703,6 +705,10 @@ class SimpleConv: ...@@ -703,6 +705,10 @@ class SimpleConv:
layout_w.interleave, layout_o.interleave, inp.dtype, layout_w.interleave, layout_o.interleave, inp.dtype,
weight.dtype, out.dtype, op_type.value) weight.dtype, out.dtype, op_type.value)
desps = self.static_key_to_desps.get(static_key, None) desps = self.static_key_to_desps.get(static_key, None)
# for d in self.all_desps:
# print(d)
# print(len(desps))
# breakpoint()
if desps is None or len(desps) == 0: if desps is None or len(desps) == 0:
return finally_algos return finally_algos
for desp in desps: for desp in desps:
...@@ -726,11 +732,21 @@ class SimpleConv: ...@@ -726,11 +732,21 @@ class SimpleConv:
ldw = weight.dim(-1) ldw = weight.dim(-1)
ldo = out.dim(-1) ldo = out.dim(-1)
mask_width_valid = True mask_width_valid = True
if desp.op_type.value == ConvOpType.kBackwardWeight.value: if desp.op_type.value == ConvOpType.kBackwardWeight.value:
assert mask_width > 0 assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0
require_dynamic_mask = kv > 32
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if not bias.empty() and not scale.empty():
# int8 inference, bias/scale dtype must equal to compute dtype in gemm
assert bias.dtype == scale.dtype
if desp.dcomp != bias.dtype:
continue
if not desp.is_int8_inference:
continue
else:
if desp.is_int8_inference:
continue
if desp.is_nvrtc: if desp.is_nvrtc:
if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch): if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
continue continue
...@@ -747,6 +763,12 @@ class SimpleConv: ...@@ -747,6 +763,12 @@ class SimpleConv:
continue continue
if SPCONV_DEBUG_NVRTC_KERNELS: if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True desp.is_nvrtc = True
if require_dynamic_mask:
if not desp.dynamic_mask:
continue
else:
if desp.dynamic_mask:
continue
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -758,11 +780,12 @@ class SimpleConv: ...@@ -758,11 +780,12 @@ class SimpleConv:
k: int, k: int,
c: int, c: int,
arch: Tuple[int, int], arch: Tuple[int, int],
mask_width: int = -1): mask_width: int = -1,
need_dynamic_mask: bool = False):
if not op_type == ConvOpType.kBackwardWeight: if not op_type == ConvOpType.kBackwardWeight:
# fwd and dgrad don't need # fwd and dgrad don't need
mask_width = -1 mask_width = -1
key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width) key = (i_dtype, w_dtype, o_dtype, k, c, arch[0], arch[1], mask_width, need_dynamic_mask)
if op_type == ConvOpType.kForward: if op_type == ConvOpType.kForward:
return self.kc_forward_cache.get(key, None) return self.kc_forward_cache.get(key, None)
elif op_type == ConvOpType.kBackwardInput: elif op_type == ConvOpType.kBackwardInput:
...@@ -795,8 +818,9 @@ class SimpleConv: ...@@ -795,8 +818,9 @@ class SimpleConv:
cudadevrt = str(cudadevrt_p) cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel], mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt, cudadevrt_path=cudadevrt,
verbose=False, verbose=True,
custom_names=custom_names) custom_names=custom_names,
verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod.load() mod.load()
return mod, kernel return mod, kernel
...@@ -824,7 +848,6 @@ class SimpleConv: ...@@ -824,7 +848,6 @@ class SimpleConv:
mask_argsort: tv.Tensor, mask_argsort: tv.Tensor,
indices: tv.Tensor, indices: tv.Tensor,
reverse_mask: bool, reverse_mask: bool,
mask_int_count: int = 1,
mask_filter: int = 0xffffffff, mask_filter: int = 0xffffffff,
mask_width: int = -1, mask_width: int = -1,
mask_output: tv.Tensor = tv.Tensor(), mask_output: tv.Tensor = tv.Tensor(),
...@@ -832,17 +855,20 @@ class SimpleConv: ...@@ -832,17 +855,20 @@ class SimpleConv:
beta: float = 0.0, beta: float = 0.0,
stream: int = 0, stream: int = 0,
fp32_accum: Optional[bool] = None, fp32_accum: Optional[bool] = None,
use_tf32: bool = True): use_tf32: bool = True,
bias: tv.Tensor = tv.Tensor(),
scale: tv.Tensor = tv.Tensor()):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width, layout_o, arch, op_type, mask_width,
fp32_accum, use_tf32) fp32_accum, use_tf32, bias, scale)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
output = output.clone() output = output.clone()
print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty())
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
weight = weight.view([channel_k, -1, channel_c])
need_dynamic_mask = weight.dim(1) > 32
times: List[float] = [] times: List[float] = []
all_profile_res: List[BestConvAlgoByProfile] = [] all_profile_res: List[BestConvAlgoByProfile] = []
group_by_algo = {} group_by_algo = {}
...@@ -865,8 +891,9 @@ class SimpleConv: ...@@ -865,8 +891,9 @@ class SimpleConv:
params.indices = indices params.indices = indices
params.mask = mask params.mask = mask
params.mask_output = mask_output params.mask_output = mask_output
params.mask_int_count = mask_int_count if desp.is_int8_inference:
params.bias = bias
params.scale = scale
# if op_type == ConvOpType.kBackwardWeight: # if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty() # assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput: if op_type == ConvOpType.kBackwardInput:
...@@ -909,7 +936,7 @@ class SimpleConv: ...@@ -909,7 +936,7 @@ class SimpleConv:
# fwd and dgrad don't need # fwd and dgrad don't need
mask_width = -1 mask_width = -1
key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c, key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
arch[0], arch[1], mask_width) arch[0], arch[1], mask_width, need_dynamic_mask)
with self.lock: with self.lock:
if op_type == ConvOpType.kForward: if op_type == ConvOpType.kForward:
self.kc_forward_cache[key] = res self.kc_forward_cache[key] = res
...@@ -945,7 +972,9 @@ class SimpleConv: ...@@ -945,7 +972,9 @@ class SimpleConv:
act_alpha: float = 0.0, act_alpha: float = 0.0,
act_beta: float = 0.0, act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_, act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
mask_int_count: Union[int, None] = None): scale: Optional[tv.Tensor] = None,
output_add: Optional[tv.Tensor] = None):
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream) # GemmMainUnitTest.stream_synchronize(stream)
...@@ -986,9 +1015,12 @@ class SimpleConv: ...@@ -986,9 +1015,12 @@ class SimpleConv:
params.mask_filter = mask_filter params.mask_filter = mask_filter
params.mask_output = mask_output params.mask_output = mask_output
params.reverse_mask = reverse_mask params.reverse_mask = reverse_mask
params.mask_int_count = mask_int_count
if bias is not None: if bias is not None:
params.bias = bias params.bias = bias
if output_add is not None and algo_desp.is_int8_inference:
params.output_add = output_add
if scale is not None and algo_desp.is_int8_inference:
params.scale = scale
if timer.enable: if timer.enable:
assert timer._timer is not None assert timer._timer is not None
params.timer = timer._timer params.timer = timer._timer
......
...@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp ...@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp], def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]): p: Union[GemmAlgoParams, ConvAlgoParams]):
desp.dtype_a = p.dtype_a.tv_dtype desp.dtype_a = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_a.tv_dtype desp.dtype_b = p.dtype_b.tv_dtype
desp.dtype_c = p.dtype_a.tv_dtype desp.dtype_c = p.dtype_c.tv_dtype
desp.dacc = p.dtype_acc.tv_dtype desp.dacc = p.dtype_acc.tv_dtype
desp.dcomp = p.dtype_comp.tv_dtype desp.dcomp = p.dtype_comp.tv_dtype
desp.trans_a = p.trans_a desp.trans_a = p.trans_a
...@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams): ...@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc
desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc
desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc
desp.is_int8_inference = ker.int8_inference
desp.dynamic_mask = ker.dynamic_mask
desp.min_arch = ker.min_arch() desp.min_arch = ker.min_arch()
return desp return desp
...@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp): ...@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
desp.interleave_o) desp.interleave_o)
p.mask_sparse = desp.mask_sparse p.mask_sparse = desp.mask_sparse
p.increment_k_first = desp.increment_k_first p.increment_k_first = desp.increment_k_first
p.int8_inference = desp.is_int8_inference
p.dynamic_mask = desp.dynamic_mask
return p return p
...@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.inference import InferenceOps from spconv.csrc.sparse.inference import InferenceOps
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) # all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) # all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
gemmtuner = GemmTunerSimple(cu) gemmtuner = GemmTunerSimple(cu)
......
...@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3], [2, 3],
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
] ]
IMPLGEMM_TURING_PARAMS = [ IMPLGEMM_TURING_PARAMS = [
...@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, 2,
"s8,s8,s8,s32,s32", ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True), is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
......
...@@ -144,7 +144,7 @@ class SpconvOps: ...@@ -144,7 +144,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -167,11 +167,10 @@ class SpconvOps: ...@@ -167,11 +167,10 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -194,11 +193,10 @@ class SpconvOps: ...@@ -194,11 +193,10 @@ class SpconvOps:
dilation: dilation:
transposed: transposed:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0, mask_int_count: int = 1) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
...@@ -214,7 +212,6 @@ class SpconvOps: ...@@ -214,7 +212,6 @@ class SpconvOps:
indice_pair_mask: indice_pair_mask:
backward: backward:
stream_int: stream_int:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
...@@ -383,65 +380,25 @@ class SpconvOps: ...@@ -383,65 +380,25 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_mask32(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
""" """
Args: Args:
data: data:
alloc_func: alloc_func:
indices: indices:
stream: stream:
mask_count:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_mask32_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor: def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
""" """
Args: Args:
data: data:
allocator: allocator:
indices: indices:
stream: stream:
""" mask_count:
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
""" """
... ...
@staticmethod @staticmethod
...@@ -598,7 +555,7 @@ class SpconvOps: ...@@ -598,7 +555,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
""" """
Args: Args:
allocator: allocator:
......
...@@ -20,7 +20,7 @@ class ConvTunerSimple: ...@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch: arch:
""" """
... ...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]: def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> List[ConvAlgoDesp]:
""" """
Args: Args:
inp: inp:
...@@ -38,6 +38,8 @@ class ConvTunerSimple: ...@@ -38,6 +38,8 @@ class ConvTunerSimple:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
use_tf32: use_tf32:
bias:
scale:
""" """
... ...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams: def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...@@ -48,7 +50,7 @@ class ConvTunerSimple: ...@@ -48,7 +50,7 @@ class ConvTunerSimple:
stream_int: stream_int:
""" """
... ...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]: def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True, bias: Tensor = Tensor(), scale: Tensor = Tensor()) -> Tuple[ConvTuneResult, float]:
""" """
Args: Args:
op_type: op_type:
...@@ -72,14 +74,15 @@ class ConvTunerSimple: ...@@ -72,14 +74,15 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
num_run: num_run:
use_tf32: use_tf32:
bias:
scale:
""" """
... ...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]: def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1, need_dynamic_mask: bool = False) -> Tuple[Any, bool]:
""" """
Args: Args:
op_type: op_type:
...@@ -90,9 +93,10 @@ class ConvTunerSimple: ...@@ -90,9 +93,10 @@ class ConvTunerSimple:
c: c:
arch: arch:
mask_width: mask_width:
need_dynamic_mask:
""" """
... ...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, mask_int_count: int = 1, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None: def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, scale: Tensor = Tensor(), output_add: Tensor = Tensor()) -> None:
""" """
Args: Args:
profile_res: profile_res:
...@@ -110,7 +114,6 @@ class ConvTunerSimple: ...@@ -110,7 +114,6 @@ class ConvTunerSimple:
alpha: alpha:
beta: beta:
stream_int: stream_int:
mask_int_count:
workspace: workspace:
verbose: verbose:
timer: timer:
...@@ -119,6 +122,8 @@ class ConvTunerSimple: ...@@ -119,6 +122,8 @@ class ConvTunerSimple:
act_alpha: act_alpha:
act_beta: act_beta:
act_type: act_type:
scale:
output_add:
""" """
... ...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int: def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
......
...@@ -63,7 +63,7 @@ class ConvGemmOps: ...@@ -63,7 +63,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
...@@ -75,7 +75,6 @@ class ConvGemmOps: ...@@ -75,7 +75,6 @@ class ConvGemmOps:
mask_argsort_fwd_splits: mask_argsort_fwd_splits:
num_activate_out: num_activate_out:
masks: masks:
mask_int_count:
arch: arch:
is_train: is_train:
is_subm: is_subm:
...@@ -88,10 +87,14 @@ class ConvGemmOps: ...@@ -88,10 +87,14 @@ class ConvGemmOps:
act_beta: act_beta:
act_type: act_type:
use_tf32: use_tf32:
output_scale:
scale:
output_add:
output_add_scale:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_int_count: int, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None: def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -107,7 +110,6 @@ class ConvGemmOps: ...@@ -107,7 +110,6 @@ class ConvGemmOps:
mask_argsort_bwd_splits: mask_argsort_bwd_splits:
mask_output_fwd: mask_output_fwd:
masks: masks:
mask_int_count:
arch: arch:
mask_width: mask_width:
is_subm: is_subm:
......
...@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator ...@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys from spconv.constants import SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, AllocKeys
import re import re
import os import os
from cumm.gemm.codeops import dispatch
class CustomThrustLib(pccm.Class): class CustomThrustLib(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class): ...@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class): ...@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class): ...@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
code.arg("ksize, stride, padding, dilation", f"std::vector<int>") code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim && TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class): ...@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort, indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd, out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int, mask_int_count); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class): ...@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("backward", "bool", "false") code.arg("backward", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int ndim = indices.dim(1) - 1; int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim && TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class): ...@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc, indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_, batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward, ksize_, dilation_, indice_pair_mask, backward,
stream_int, mask_int_count); stream_int);
}} }}
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
...@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class): ...@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
""") """)
return code return code
def sort_1d_by_key_allocator_template(self, use_allocator: bool, int_count: int = 1): def sort_1d_by_key_allocator_template(self, use_allocator: bool):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
...@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class): ...@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
"tv::Tensor()", "tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()") pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.code_after_include = f""" code.arg("mask_count", "int", "1", pyanno="int")
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel) code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator: if not use_allocator:
...@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class): ...@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{ if (indices.empty()){{
indices = tv::empty({{data.dim(0) / {int_count}}}, tv::int32, 0); indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}} }}
tv::cuda::Launch launcher(data.dim(0), stream_cu); tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0)); launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer(); // auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{ """)
using T_ = TV_DECLTYPE(I); # nested tv::dispatch may cause compiler bug in msvc.
using T = {"T_" if int_count == 1 else f"thrust::tuple<{', '.join(['T_'] * int_count) }>"}; for dtype in dispatch(code, [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64], "data.dtype()"):
code.raw(f"""
using T_ = {dtype};
tv::dispatch_int<1, 2, 3, 4>(mask_count, [&](auto IV){{
constexpr int I = TV_DECLTYPE(IV)::value;
// we can't use thrust::tuple in mp_repeat_c directly because
// thrust tuple actually has fixed size template arguments.
using T = tv::mp_rename<tv::mp_repeat_c<tv::mp_list<T_>, I>, thrust::tuple>;
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>())); thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>()); thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu); auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu); auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) / {int_count}, ptr_k); thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}}); }});
""")
code.raw(f"""
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0); // tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices; return indices;
""") """)
...@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class): ...@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32(self): def sort_1d_by_key_allocator(self):
# for python
return self.sort_1d_by_key_allocator_template(False) return self.sort_1d_by_key_allocator_template(False)
@pccm.pybind.mark @pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask32_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128(self):
# for python
return self.sort_1d_by_key_allocator_template(False, 4)
@pccm.pybind.mark
@_STATIC_FUNCTION
def sort_1d_by_key_allocator_mask128_v2(self):
# for python
return self.sort_1d_by_key_allocator_template(True, 4)
def sort_1d_by_key_allocator_mask_auto_template(self, use_allocator: bool):
code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
code.arg("data", "tv::Tensor")
if not use_allocator:
code.arg("alloc_param", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("alloc_param", "ThrustAllocator&")
code.arg("indices",
"tv::Tensor",
"tv::Tensor()",
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.raw(f"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128{"_v2" if use_allocator else ""}(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto(self):
return self.sort_1d_by_key_allocator_mask_auto_template(False)
@pccm.pybind.mark
@pccm.static_function
def sort_1d_by_key_allocator_mask_auto_v2(self):
return self.sort_1d_by_key_allocator_mask_auto_template(True)
@_STATIC_FUNCTION @_STATIC_FUNCTION
def sort_1d_by_key_allocator_v2(self): def sort_1d_by_key_allocator_v2(self):
# for cpp only
return self.sort_1d_by_key_allocator_template(True) return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark @pccm.pybind.mark
...@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class): ...@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
code.raw(f""" code.raw(f"""
int hash_size = 2 * num_act_out_bound; int hash_size = 2 * num_act_out_bound;
if (direct_table){{ if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory); hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}} }}
size_t res = 0; size_t res = 0;
if (subm){{ if (subm){{
...@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class): ...@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
max_act_out_in_theory, subm, use_int64_hash_k, direct_table); max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound; int hash_size = 2 * num_act_out_bound;
if (direct_table){{ if (direct_table){{
hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory); hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2);
}} }}
if (use_int64_hash_k){{ if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0); auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
...@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class): ...@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int)); tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo); auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>()); int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int mask_int_count = (kv + 31) / 32; int mask_int_count = tv::div_up(kv, 32);
if (mask_int_count > 1 && mask_int_count < 4) // if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4; // mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel"); // TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32"); // TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape; std::vector<int> out_shape;
if (!subm){{ if (!subm){{
...@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class): ...@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)}); pair_mask = preallocated.at({pccm.literal(AllocKeys.PairMask)});
}}else{{ }}else{{
pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)}, pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_in * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_in, mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc, generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int, mask_int_count); batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count); sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}} }}
""") """)
with code.else_(): with code.else_():
...@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n"; ...@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)}, pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int); {{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)}, pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, num_act_out, mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor(); pair_mask_bwd = tv::Tensor();
if (is_train){{ if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)}, pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0) * mask_int_count}}, tv::uint32, 0, stream_int); {{mask_split_count, indices.dim(0), mask_int_count}}, tv::uint32, 0, stream_int);
}} }}
}} }}
if (!direct_table){{ if (!direct_table){{
...@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n"; ...@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count); transposed, stream_int);
}}else{{ }}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd, generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp, indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out, out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation, batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int, mask_int_count); transposed, stream_int);
}} }}
}} }}
""") """)
...@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n"; ...@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
}} }}
}}else{{ }}else{{
if (!is_train){{ if (!is_train){{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{ }}else{{
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_mask_auto_v2(pair_mask_bwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count); mask_argsort_bwd[0], stream_int, mask_int_count);
}} }}
}} }}
}} }}
""") """)
code.raw(f""" code.raw(f"""
return std::make_tuple(mask_tensor, num_act_out, mask_int_count); return std::make_tuple(mask_tensor, num_act_out);
""") """)
return code.ret("std::tuple<tv::Tensor, int, int>") return code.ret("std::tuple<tv::Tensor, int>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.static_function @pccm.static_function
......
...@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
"int, int, int, int, int>")) "int, int, int, int, int>"))
self.add_typedef( self.add_typedef(
"algo_cache_key_t", "std::tuple<int, int, int, int, " "algo_cache_key_t", "std::tuple<int, int, int, int, "
"int, int, int, int>") "int, int, int, int, bool>")
self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>") self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>")
self.add_member( self.add_member(
...@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("auto_fp32_accum", "bool") code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool") code.arg("fp32_accum", "bool")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.raw(f""" code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type); tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
...@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(mask_width > 0, "eroro"); TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0; mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}} }}
bool require_dynamic_mask = kv > 32;
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{ if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!bias.empty() && !scale.empty()){{
TV_ASSERT_RT_ERR(bias.dtype() == scale.dtype(), "bias/scale dtype must equal to compute dtype in gemm");
if (desp.dcomp != bias.dtype()){{
continue;
}}
if (!desp.is_int8_inference){{
continue;
}}
}}else{{
if (desp.is_int8_inference){{
continue;
}}
}}
auto desp2 = desp; auto desp2 = desp;
if (desp.is_nvrtc){{ if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{ if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
...@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
}} }}
}} }}
if (require_dynamic_mask){{
if (!desp.dynamic_mask){{
continue;
}}
}}else{{
if (desp.dynamic_mask){{
continue;
}}
}}
finally_algos.push_back(desp2); finally_algos.push_back(desp2);
}} }}
}} }}
...@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("mask_int_count", "int", "1")
code.arg("auto_fp32_accum", "bool", "true") code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false") code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5") code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("bias", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
...@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w, auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o, layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width, arch, op_type, mask_width,
auto_fp32_accum, fp32_accum, use_tf32); auto_fp32_accum, fp32_accum, use_tf32,
bias, scale);
inp = inp.clone(); inp = inp.clone();
weight = weight.clone(); weight = weight.clone();
bool need_dynamic_mask = weight.dim(1) > 32;
output = output.clone(); output = output.clone();
int channel_k = output.dim(1); int channel_k = output.dim(1);
int channel_c = inp.dim(1); int channel_c = inp.dim(1);
weight = weight.view(channel_k, -1, channel_c);
std::vector<ConvTuneResult> all_profile_res; std::vector<ConvTuneResult> all_profile_res;
std::unordered_set<int> splitk_tests; std::unordered_set<int> splitk_tests;
...@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices; params.indices = indices;
params.mask = mask; params.mask = mask;
params.mask_output = mask_output; params.mask_output = mask_output;
params.mask_int_count = mask_int_count; if (desp.is_int8_inference){{
params.bias = bias;
params.scale = scale;
}}
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{ // if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error"); // TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
...@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
algo_cache_key_t key; algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()), key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width); int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
{{ {{
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
...@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("k, c", "int") code.arg("k, c", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int", "-1") code.arg("mask_width", "int", "-1")
code.arg("need_dynamic_mask", "bool", "false")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("std::tuple<ConvTuneResult, bool>") return code.ret("std::tuple<ConvTuneResult, bool>")
...@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}} }}
algo_cache_key_t key; algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c, key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width); std::get<0>(arch), std::get<1>(arch), mask_width, need_dynamic_mask);
ConvTuneResult res; ConvTuneResult res;
bool exists = false; bool exists = false;
{{ {{
...@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.arg("workspace", "tv::Tensor", "tv::Tensor()", code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false") code.arg("verbose", f"bool", "false")
...@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0") code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code return code
...@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.output = output; params.output = output;
params.verbose = verbose; params.verbose = verbose;
params.bias = bias; params.bias = bias;
params.scale = scale;
params.split_k_slices = split_k_slices; params.split_k_slices = split_k_slices;
params.alpha = alpha; params.alpha = alpha;
...@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.act_alpha = act_alpha; params.act_alpha = act_alpha;
params.act_beta = act_beta; params.act_beta = act_beta;
params.act_type = act_type; params.act_type = act_type;
if (!output_add.empty() && desp.is_int8_inference){{
params.output_add = output_add;
}}
params.stream = stream_int; params.stream = stream_int;
params.mask_argsort = mask_argsort; params.mask_argsort = mask_argsort;
params.indices = indices; params.indices = indices;
...@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width; params.mask_width = mask_width;
params.mask_output = mask_output; params.mask_output = mask_output;
params.reverse_mask = reverse_mask; params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{ if (timer.enable()){{
params.timer = timer; params.timer = timer;
...@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>") "std::vector<tv::Tensor>")
code.arg("num_activate_out", "int") code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false") code.arg("is_train, is_subm", "bool", "false")
...@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true") code.arg("use_tf32", "bool", "true")
code.arg("output_scale", "float", "1.0")
code.arg("scale", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("output_add_scale", "float", "1.0")
code.arg("output_dtype", "int", "-1")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
...@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
int num_split = pair_mask_fwd_splits.size(); int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error"); TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel); filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
int mask_int_count = tv::div_up(kv, 32);
tv::Tensor out_features; tv::Tensor out_features;
if (output_dtype < 0){{
output_dtype = int(features.dtype());
}}
if (is_subm){{ if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}}else{{ }}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
}} }}
// auto start_ev = tv::CUDAEvent(); // auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int); // start_ev.record(stream_int);
...@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
use_tf32); use_tf32,
bias,
scale);
tune_res = std::get<0>(tune_res_time); tune_res = std::get<0>(tune_res_time);
}} }}
float alpha = 1.0;
if (tune_res.algo_desp.is_int8_inference){{
alpha = output_scale;
}}
int mask_width = tune_res.algo_desp.tile_shape[0]; int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd; tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits; std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{ if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)}, mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width) * mask_int_count}}, {{num_split, tv::div_up(num_activate_out, mask_width), mask_int_count}},
tv::uint32, features.device(), stream_int); tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{ for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]); mask_output_fwd_splits.push_back(mask_output_fwd[i]);
...@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
for (int j = 0; j < num_split; ++j){{ for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1; float beta = j == 0 ? 0 : 1;
if (!bias.empty()){{ if (!bias.empty() && !tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = 1; beta = 1;
}} }}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
}}
if (j > 0){{ if (j > 0){{
bias = tv::Tensor(); bias = tv::Tensor();
}} }}
...@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // reverse_mask false, // reverse_mask
mask_ptr[j], mask_ptr[j],
-1, // mask_width -1, // mask_width
1.0, beta, alpha, beta,
stream_int, stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer, timer,
...@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
bias, bias,
act_alpha, act_alpha,
act_beta, act_beta,
act_type); act_type,
scale,
output_add);
}} }}
// auto end_ev = tv::CUDAEvent(); // auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int); // end_ev.record(stream_int);
...@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor") code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("mask_int_count", "int")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int") code.arg("mask_width", "int")
...@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output tv::Tensor(), // mask_output
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
auto_fp32_accum, auto_fp32_accum,
fp32_accum, fp32_accum,
5, // num_run 5, // num_run
...@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width -1, // mask_width
1.0, beta, 1.0, beta,
stream_int, stream_int,
mask_int_count,
tv::Tensor(), // workspace tv::Tensor(), // workspace
false, // verbose false, // verbose
timer); timer);
...@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width, mask_width,
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
mask_int_count,
workspace, // workspace workspace, // workspace
false, // verbose false, // verbose
timer); timer);
......
...@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32)); uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32; uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2)); // uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset); loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS; int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size; int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
...@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
// indice_pairs_bwd: [kv, num_act_in] or empty // indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, num_act_out] // indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context(); auto ctx = tv::Context();
...@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()") "cumm.tensorview.Tensor = Tensor()")
code.arg("is_train", "bool", "true") code.arg("is_train", "bool", "true")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("mask_int_count", "int", "1")
code.raw(f""" code.raw(f"""
int num_act_in_real = indices.dim(0); int num_act_in_real = indices.dim(0);
...@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
padding[i] = (ksize[i] / 2) * dilation[i]; padding[i] = (ksize[i] / 2) * dilation[i];
}} }}
int kv = ksize.op<tv::arrayops::prod>(); int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in] // indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
...@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error"); TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error"); TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error"); TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 3, "error");
// indice_pair_mask: [mask_split_count, num_act_in] // indice_pair_mask: [mask_split_count, num_act_in, num_mask_per_point]
if (indice_pair_mask.dim(0) == 2){{ if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real); auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real); auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
...@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv, is_train); indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{ }}else{{
// indice_pair_mask: [1, num_act_in] // indice_pair_mask: [1, num_act_in, num_mask_per_point]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream); tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1) if (mask_int_count == 1){{
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0)); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else }}
else{{
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>, lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count); indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
}}
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t, {loc_type}>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
......
...@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(), table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(), count.data_ptr<int>(),
layout, voxels.dim(0)); layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(), launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(), point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<{self.dtype}>(),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1), num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv, voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0)); grid_size_tv, grid_stride_tv, points.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
auto voxel_launcher = tv::cuda::Launch(count_val, custream); auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{ if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(), launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<{self.dtype}>(),
......
...@@ -37,10 +37,23 @@ from spconv.utils import nullcontext ...@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
from torch.nn.init import calculate_gain from torch.nn.init import calculate_gain
from cumm import tensorview as tv from cumm import tensorview as tv
from torch.nn import functional as F
FILTER_HWIO = False FILTER_HWIO = False
_MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training" _MAX_NUM_VOXELS_DURING_TRAINING = "max_num_voxels_during_training"
def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float, act_beta: float):
if act_type == tv.gemm.Activation.None_:
return x
elif act_type == tv.gemm.Activation.ReLU:
return F.relu(x)
elif act_type == tv.gemm.Activation.Sigmoid:
return F.sigmoid(x)
elif act_type == tv.gemm.Activation.LeakyReLU:
return F.leaky_relu(x, act_alpha)
else:
raise NotImplementedError
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
...@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule): ...@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
torch.zeros(1, dtype=torch.int32)) torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count self.record_voxel_count = record_voxel_count
if algo is None: if algo is None:
if kv <= 32 and not CPU_ONLY_BUILD: if kv <= 128 and not CPU_ONLY_BUILD:
if kv < 8: if kv < 8:
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
...@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule): ...@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
self.act_type = act_type self.act_type = act_type
self.act_alpha = act_alpha self.act_alpha = act_alpha
self.act_beta = act_beta self.act_beta = act_beta
self.enable_int8_test_mode: bool = False
self._int8_weight = torch.Tensor()
# calculated by max(abs(weight)) for each channel
self._int8_weight_scale = torch.Tensor()
# calculated by scale self.bias with _int8_input_scale
self._int8_bias = torch.Tensor()
# int8 inference must set _int8_input_scale
self._int8_input_scale: Optional[float] = None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self._int8_output_scale: Optional[float] = None
if self.conv1x1: if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act" assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
self.reset_parameters() self.reset_parameters()
...@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule): ...@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING) return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None return None
def set_int8_test(self, enable: bool, input_scale: float, output_scale: Optional[float] = None, weight_scale: Optional[torch.Tensor] = None):
self._int8_input_scale = input_scale
self._int8_output_scale = output_scale
if weight_scale is not None:
self._int8_weight_scale = weight_scale
self.enable_int8_test_mode = enable
def _load_weight_different_layout(self, state_dict, prefix, local_metadata, def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, strict, missing_keys, unexpected_keys,
error_msgs): error_msgs):
if self.record_voxel_count and not self.subm and not self.inverse and _MAX_NUM_VOXELS_DURING_TRAINING not in state_dict: name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
state_dict[prefix + _MAX_NUM_VOXELS_DURING_TRAINING] = torch.zeros( if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32) 1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT: if not SAVED_WEIGHT_LAYOUT:
return return
...@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule): ...@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
def is_inverseable(self): def is_inverseable(self):
return self.indice_key is not None and not self.subm return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor): def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(input, self.weight, self.bias, add_input)
def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None):
assert isinstance(input, SparseConvTensor) assert isinstance(input, SparseConvTensor)
assert input.features.shape[ assert input.features.shape[
1] == self.in_channels, "channel size mismatch" 1] == self.in_channels, "channel size mismatch"
...@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule): ...@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
indices = input.indices indices = input.indices
spatial_shape = input.spatial_shape spatial_shape = input.spatial_shape
batch_size = input.batch_size batch_size = input.batch_size
bias_for_training = self.bias if self.training else None bias_for_training = bias if self.training else None
bias_for_infer = self.bias if not self.training else None bias_for_infer = bias if not self.training else None
output_scale = None
output_add_scale = 1.0
if self.enable_int8_test_mode:
assert not self.training, "must in eval mode"
assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
assert bias_for_infer is not None, "conv-bn-relu must be fused"
assert self._int8_input_scale is not None
if features.dtype != torch.int8:
# quantize
features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
output_scale = self._int8_output_scale
int8_out_scale = output_scale
if int8_out_scale is None:
int8_out_scale = 1
if add_input is not None:
assert add_input.int8_scale is not None, "only support int8 add"
output_add_scale = add_input.int8_scale
if self._int8_weight.numel() == 0:
with torch.no_grad():
assert ALL_WEIGHT_IS_KRSC
weight_scales = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
num_1s = [1] * (self.ndim + 1)
self._int8_weight = (weight / weight_scales.view(self.out_channels, *num_1s) * 127).to(torch.int8)
if self._int8_weight_scale.numel() == 0:
self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scales)
self._int8_bias = bias_for_infer * int8_out_scale
if self.training: if self.training:
msg = "act don't support backward, only used in inference" msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg assert self.act_type == tv.gemm.Activation.None_, msg
...@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule): ...@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
"out_channels": self.out_channels, "out_channels": self.out_channels,
} }
} }
if self.conv1x1: if self.conv1x1 and not self.enable_int8_test_mode:
# in int8 test mode, we don't implement conv1x1 via mm.
if FILTER_HWIO: if FILTER_HWIO:
features = torch.mm( features = torch.mm(
input.features, input.features,
self.weight.view(self.out_channels, self.in_channels).T) weight.view(self.out_channels, self.in_channels).T)
else: else:
features = torch.mm( features = torch.mm(
input.features, input.features,
self.weight.view(self.in_channels, self.out_channels)) weight.view(self.in_channels, self.out_channels))
if self.bias is not None: if bias is not None:
features += self.bias features += bias
out_tensor = out_tensor.replace_feature(features) out_tensor = out_tensor.replace_feature(features)
# padding may change spatial shape of conv 1x1. # padding may change spatial shape of conv 1x1.
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
...@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule): ...@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv( out_features = Fsp.indice_subm_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule): ...@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule): ...@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
else: else:
out_features = Fsp.indice_conv( out_features = Fsp.indice_conv(
features, features,
self.weight, weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], outids.shape[0],
...@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule): ...@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks masks = datas.masks
mask_int_count = datas.mask_int_count
assert self.subm, "only support reuse subm indices" assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, self._check_subm_reuse_valid(input, spatial_shape,
datas) datas)
else: else:
if input.benchmark:
torch.cuda.synchronize()
t = time.time()
with input._timer.namespace("gen_pairs"): with input._timer.namespace("gen_pairs"):
# we need to gen bwd indices for regular conv # we need to gen bwd indices for regular conv
# because it may be inversed. # because it may be inversed.
...@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule): ...@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data(indices) spconv_save_debug_data(indices)
raise e raise e
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
pair_fwd = res[2] pair_fwd = res[2]
...@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule): ...@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits = res[6] mask_argsort_fwd_splits = res[6]
mask_argsort_bwd_splits = res[7] mask_argsort_bwd_splits = res[7]
masks = res[8] masks = res[8]
mask_int_count = res[9]
if self.indice_key is not None: if self.indice_key is not None:
indice_data = ImplicitGemmIndiceData( indice_data = ImplicitGemmIndiceData(
outids, outids,
...@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule): ...@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
ksize=self.kernel_size, ksize=self.kernel_size,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
dilation=self.dilation, dilation=self.dilation)
mask_int_count=mask_int_count)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor." msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data indice_dict[self.indice_key] = indice_data
...@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule): ...@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
num_activate_out = outids.shape[0] num_activate_out = outids.shape[0]
weight_cur = weight
bias_cur = bias_for_infer
if self.enable_int8_test_mode:
assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
weight_cur = self._int8_weight
bias_cur = self._int8_bias
if self.training:
out_features = Fsp.implicit_gemm( out_features = Fsp.implicit_gemm(
features, self.weight, pair_fwd, pair_bwd, features, weight_cur, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, mask_int_count, self.training, self.subm, num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum, input._timer, self.fp32_accum,
bias_for_infer, bias_cur,
self.act_alpha, self.act_alpha,
self.act_beta, self.act_beta,
self.act_type) self.act_type)
else:
output_dtype = None
if self._int8_output_scale is None:
output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type,
# TODO do we really need output scale to scale bias in kernel?
1.0, # output_scale
self._int8_weight_scale, # scale
output_add=add_input.features if add_input is not None else None,
output_add_scale=output_add_scale,
output_dtype=output_dtype)
if bias_for_training is not None: if bias_for_training is not None:
out_features += bias_for_training out_features += bias_for_training
if input.benchmark: if input.benchmark:
...@@ -581,8 +675,9 @@ class SparseConvolution(SparseModule): ...@@ -581,8 +675,9 @@ class SparseConvolution(SparseModule):
out_tensor.indices = outids out_tensor.indices = outids
out_tensor.indice_dict = indice_dict out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
# print(outids.shape, spatial_shape, self.kernel_size, self.stride, self.padding, if add_input is not None and not self.enable_int8_test_mode:
# self.dilation, self.output_padding, out_spatial_shape) out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
out_tensor.int8_scale = output_scale
return out_tensor return out_tensor
......
...@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object): ...@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape, is_subm: bool, algo: ConvAlgo, out_spatial_shape, is_subm: bool, algo: ConvAlgo,
ksize: List[int], stride: List[int], dilation: List[int], padding: List[int], ksize: List[int], stride: List[int], dilation: List[int], padding: List[int],
in_voxel_num: Optional[Any] = None, in_voxel_num: Optional[Any] = None,
out_voxel_num: Optional[Any] = None, out_voxel_num: Optional[Any] = None):
mask_int_count: int=1):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object): ...@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion. # in/out voxel_num is only used in tensorrt conversion.
self.in_voxel_num = in_voxel_num self.in_voxel_num = in_voxel_num
self.out_voxel_num = out_voxel_num self.out_voxel_num = out_voxel_num
self.mask_int_count = mask_int_count
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
...@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer) self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo self.force_algo = force_algo
# for simple int8 torch inference
self.int8_scale: Optional[float] = None
def replace_feature(self, feature: torch.Tensor): def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
......
...@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int, num_activate_out: int,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
try: try:
out, mask_out, mask_width = ops.implicit_gemm( out, mask_out, mask_width = ops.implicit_gemm(
features, filters, pair_fwd, pair_mask_fwd_splits, features, filters, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits, num_activate_out, masks, mask_int_count, is_train, mask_argsort_fwd_splits, num_activate_out, masks, is_train,
is_subm, timer, fp32_accum, bias, act_alpha, act_beta, is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
act_type) act_type)
except Exception as e: except Exception as e:
...@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
ctx.masks = masks ctx.masks = masks
ctx.is_subm = is_subm ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum ctx.fp32_accum = fp32_accum
ctx.mask_int_count = mask_int_count
return out return out
@staticmethod @staticmethod
...@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer timer = ctx.timer
fp32_accum = ctx.fp32_accum fp32_accum = ctx.fp32_accum
mask_int_count = ctx.mask_int_count
try: try:
input_bp, filters_bp = ops.implicit_gemm_backward( input_bp, filters_bp = ops.implicit_gemm_backward(
...@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function): ...@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, mask_argsort_bwd_splits,
mask_output_fwd=mask_out, mask_output_fwd=mask_out,
masks=masks, masks=masks,
mask_int_count=mask_int_count,
mask_width=mask_width, mask_width=mask_width,
is_subm=is_subm, is_subm=is_subm,
timer=timer, timer=timer,
...@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits, masks)) mask_argsort_bwd_splits, masks))
raise e raise e
None_9 = [None] * 17 None_9 = [None] * 16
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
......
...@@ -23,7 +23,7 @@ import spconv ...@@ -23,7 +23,7 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul from spconv.pytorch.cppcore import _TORCH_DTYPE_TO_TV, TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32 from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32
...@@ -31,7 +31,7 @@ import spconv.core_cc as _ext ...@@ -31,7 +31,7 @@ import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps from spconv.core_cc.csrc.sparse.inference import InferenceOps
from spconv.cppconstants import CPU_ONLY_BUILD from spconv.cppconstants import CPU_ONLY_BUILD
from cumm.gemm.codeops import div_up
from spconv.utils import nullcontext from spconv.utils import nullcontext
if not CPU_ONLY_BUILD: if not CPU_ONLY_BUILD:
...@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm(
timer_cpp = tv.CUDAKernelTimer(False) timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None: if timer._timer is not None:
timer_cpp = timer._timer timer_cpp = timer._timer
mask_tensor, num_act_out, mask_int_count = SpconvOps.get_indice_pairs_implicit_gemm( mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, thalloc,
torch_tensor_to_tv(indices), torch_tensor_to_tv(indices),
batch_size, batch_size,
...@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm(
assert pair.shape[0] == 2 assert pair.shape[0] == 2
pair_bwd = pair[1] pair_bwd = pair[1]
return (out_inds, indice_num_per_loc, pair[0], pair_bwd, return (out_inds, indice_num_per_loc, pair[0], pair_bwd,
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor()) pair_bwd = thalloc.allocated.get(AllocKeys.PairBwd, torch.Tensor())
pair_fwd = thalloc.allocated[AllocKeys.PairFwd] pair_fwd = thalloc.allocated[AllocKeys.PairFwd]
...@@ -437,16 +437,13 @@ def get_indice_pairs_implicit_gemm( ...@@ -437,16 +437,13 @@ def get_indice_pairs_implicit_gemm(
] ]
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count) mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
assert indices.is_cuda, "implicit gemm only support cuda" assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1) kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
# TODO in future we will support up to 128 kernel volume. # TODO in future we will support up to 128 kernel volume.
# assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm" # assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
mask_int_count = (kv + 31) // 32 mask_int_count = div_up(kv, 32)
if 1 < mask_int_count < 4:
mask_int_count = 4
assert mask_int_count in [1, 4]
if not subm: if not subm:
if transpose: if transpose:
...@@ -511,7 +508,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -511,7 +508,7 @@ def get_indice_pairs_implicit_gemm(
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device) indices.device)
pair_mask = torch.empty((mask_split_count, indices.shape[0] * mask_int_count), pair_mask = torch.empty((mask_split_count, indices.shape[0], mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
...@@ -531,8 +528,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -531,8 +528,7 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation, dilation=dilation,
indice_pair_mask=pair_mask_tv, indice_pair_mask=pair_mask_tv,
backward=is_train, backward=is_train,
stream_int=stream, stream_int=stream)
mask_int_count=mask_int_count)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print("SUBM0", time.time() - t) # print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
...@@ -549,7 +545,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -549,7 +545,7 @@ def get_indice_pairs_implicit_gemm(
# so I use this stupid hack to use torch allocator without touch # so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++). # pytorch binary (c++).
# f**k thrust # f**k thrust
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_tv[j], SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j],
alloc.alloc, alloc.alloc,
mask_argsort_tv[j], stream, mask_argsort_tv[j], stream,
mask_int_count) mask_int_count)
...@@ -560,10 +556,10 @@ def get_indice_pairs_implicit_gemm( ...@@ -560,10 +556,10 @@ def get_indice_pairs_implicit_gemm(
] ]
if is_train: if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1], return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(), return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks, mask_int_count) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
max_num_act = SpconvOps.get_handcrafted_max_act_out( max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation) indices.shape[0], ksize, stride, padding, dilation)
...@@ -621,7 +617,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -621,7 +617,6 @@ def get_indice_pairs_implicit_gemm(
stream_int=stream) stream_int=stream)
uniq_out_indices_offset_tv = tv.Tensor() uniq_out_indices_offset_tv = tv.Tensor()
with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream): with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream):
if direct_table: if direct_table:
uniq_cnt = torch.zeros([1], uniq_cnt = torch.zeros([1],
dtype=torch.int32, dtype=torch.int32,
...@@ -655,7 +650,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -655,7 +650,7 @@ def get_indice_pairs_implicit_gemm(
-1, -1,
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
pair_mask_fwd = torch.zeros((mask_split_count, num_act_out * mask_int_count), pair_mask_fwd = torch.zeros((mask_split_count, num_act_out, mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
...@@ -665,7 +660,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -665,7 +660,7 @@ def get_indice_pairs_implicit_gemm(
pair_mask_bwd_tv = tv.Tensor() pair_mask_bwd_tv = tv.Tensor()
if is_train: if is_train:
pair_mask_bwd = torch.zeros( pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0] * mask_int_count), (mask_split_count, indices.shape[0], mask_int_count),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd, pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
...@@ -713,8 +708,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -713,8 +708,7 @@ def get_indice_pairs_implicit_gemm(
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream, stream_int=stream)
mask_int_count=mask_int_count)
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]), mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
...@@ -766,24 +760,24 @@ def get_indice_pairs_implicit_gemm( ...@@ -766,24 +760,24 @@ def get_indice_pairs_implicit_gemm(
else: else:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): # if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if not is_train: if not is_train:
SpconvOps.sort_1d_by_key_allocator_mask_auto(pair_mask_fwd_tv[0], SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc, alloc.alloc,
mask_argsort_fwd_tv[0], mask_argsort_fwd_tv[0],
stream, stream,
mask_int_count) mask_int_count)
else: else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator_mask_auto( SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count) mask_argsort_bwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto( SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count) mask_argsort_fwd_tv[0], stream, mask_int_count)
else: else:
SpconvOps.sort_1d_by_key_allocator_mask_auto( SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count) mask_argsort_fwd_tv[0], stream, mask_int_count)
SpconvOps.sort_1d_by_key_allocator_mask_auto( SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count) mask_argsort_bwd_tv[0], stream, mask_int_count)
...@@ -808,7 +802,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -808,7 +802,7 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks, mask_int_count) mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
def indice_conv(features: torch.Tensor, def indice_conv(features: torch.Tensor,
...@@ -1457,7 +1451,6 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1457,7 +1451,6 @@ def implicit_gemm(features: torch.Tensor,
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
num_activate_out: int, num_activate_out: int,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -1465,16 +1458,31 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1465,16 +1458,31 @@ def implicit_gemm(features: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
act_alpha: float = 0.0, act_alpha: float = 0.0,
act_beta: float = 0.0, act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_): act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
output_scale: float = 1.0,
scale: Optional[torch.Tensor] = None,
output_add: Optional[torch.Tensor] = None,
output_add_scale: float = 1.0,
output_dtype: Optional[torch.dtype] = None):
stream = get_current_stream() stream = get_current_stream()
bias_tv = tv.Tensor() bias_tv = tv.Tensor()
scale_tv = tv.Tensor()
output_add_tv = tv.Tensor()
if output_add is not None:
assert features.dtype == torch.int8, "fused residual add only support int8"
if bias is not None: if bias is not None:
bias_tv = torch_tensor_to_tv(bias) bias_tv = torch_tensor_to_tv(bias)
if scale is not None:
scale_tv = torch_tensor_to_tv(scale)
if output_add is not None:
output_add_tv = torch_tensor_to_tv(output_add)
if not features.is_contiguous(): if not features.is_contiguous():
features = features.contiguous() features = features.contiguous()
assert features.is_contiguous() assert features.is_contiguous()
assert filters.is_contiguous() assert filters.is_contiguous()
if output_dtype is None:
output_dtype = features.dtype
if SPCONV_CPP_GEMM and CONV_CPP is not None: if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device) alloc = TorchAllocator(features.device)
...@@ -1497,13 +1505,15 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1497,13 +1505,15 @@ def implicit_gemm(features: torch.Tensor,
if fp32_accum is None: if fp32_accum is None:
fp32_accum = False fp32_accum = False
arch = get_arch() arch = get_arch()
output_dtype_tv = _TORCH_DTYPE_TO_TV[output_dtype]
mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm( mask_width, tune_res_cpp = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, mask_int_count, arch, is_train, is_subm, stream, num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type, timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32, output_scale=output_scale,
scale=scale_tv, output_add=output_add_tv, output_add_scale=output_add_scale,
output_dtype=output_dtype_tv)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train: if is_train:
...@@ -1515,8 +1525,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1515,8 +1525,8 @@ def implicit_gemm(features: torch.Tensor,
# t = time.time() # t = time.time()
if features.dtype == torch.int8 or features.dtype == torch.qint8: # if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") # raise NotImplementedError("work in progress")
# here filters is KRSC # here filters is KRSC
masks_ints = [m.item() for m in masks] masks_ints = [m.item() for m in masks]
out_channel = filters.shape[0] out_channel = filters.shape[0]
...@@ -1524,13 +1534,14 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1524,13 +1534,14 @@ def implicit_gemm(features: torch.Tensor,
num_split = len(pair_mask_fwd_splits) num_split = len(pair_mask_fwd_splits)
filters = filters.reshape(out_channel, -1, filters.shape[-1]) filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1] kv = filters.shape[1]
mask_int_count = div_up(kv, 32)
if is_subm: if is_subm:
out_features = torch.empty((num_activate_out, out_channel), out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype, dtype=output_dtype,
device=features.device) device=features.device)
else: else:
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=output_dtype,
device=features.device) device=features.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
...@@ -1568,13 +1579,13 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1568,13 +1579,13 @@ def implicit_gemm(features: torch.Tensor,
stream=stream, stream=stream,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32, use_tf32=constants.SPCONV_ALLOW_TF32,
mask_int_count=mask_int_count) bias=bias_tv, scale=scale_tv)
mask_width = tune_res.algo_desp.tile_shape[0] mask_width = tune_res.algo_desp.tile_shape[0]
if is_train: if is_train:
mask_output_fwd = torch.empty( mask_output_fwd = torch.empty(
[num_split, [num_split,
codeops.div_up(num_activate_out, mask_width) * mask_int_count], codeops.div_up(num_activate_out, mask_width), mask_int_count],
dtype=torch.int32, dtype=torch.int32,
device=features.device) device=features.device)
# pytorch don't support uint32. # pytorch don't support uint32.
...@@ -1597,12 +1608,16 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1597,12 +1608,16 @@ def implicit_gemm(features: torch.Tensor,
bias_tv = tv.Tensor() bias_tv = tv.Tensor()
if bias is not None: if bias is not None:
bias_tv = torch_tensor_to_tv(bias) bias_tv = torch_tensor_to_tv(bias)
alpha = 1.0
if tune_res.algo_desp.is_int8_inference:
alpha = output_scale
with timer.record("implicit_gemm", stream): with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
beta = 0 if j == 0 else 1 beta = 0 if j == 0 else 1
if bias is not None: if bias is not None and not tune_res.algo_desp.is_int8_inference:
beta = 1 beta = 1
if output_add is not None and tune_res.algo_desp.is_int8_inference:
beta = output_add_scale
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
tune_res, tune_res,
ConvOpType.kForward, ConvOpType.kForward,
...@@ -1616,6 +1631,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1616,6 +1631,7 @@ def implicit_gemm(features: torch.Tensor,
reverse_mask=False, reverse_mask=False,
mask_filter=masks_ints[j], mask_filter=masks_ints[j],
mask_width=-1, mask_width=-1,
alpha=alpha,
beta=beta, beta=beta,
stream=stream, stream=stream,
verbose=False, verbose=False,
...@@ -1623,91 +1639,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1623,91 +1639,8 @@ def implicit_gemm(features: torch.Tensor,
act_type=act_type, act_type=act_type,
act_alpha=act_alpha, act_alpha=act_alpha,
act_beta=act_beta, act_beta=act_beta,
mask_int_count=mask_int_count) scale=scale_tv,
# INT8_TEST = True output_add=output_add)
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# return out_features, mask_output_fwd, mask_width
# features = features.to(torch.int8)
# filters = filters.to(torch.int8)
# out_features_i8 = out_features.to(torch.int8)
# features_tv = torch_tensor_to_tv(features)
# filters_tv = torch_tensor_to_tv(filters)
# out_features_i8_tv = torch_tensor_to_tv(out_features_i8)
# tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
# filters_tv.dtype, out_features_i8_tv.dtype,
# out_channel, in_channel, arch)
# if tune_res is None:
# tune_res, _ = CONV.tune_and_cache(
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# NHWC,
# KRSC,
# NHWC,
# arch,
# mask=pair_mask_fwd_split_tvs[0],
# mask_argsort=mask_argsort_fwd_split_tvs[0],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks[0].item(),
# stream=stream,
# fp32_accum=fp32_accum)
# mask_width = tune_res.algo_desp.tile_shape[0]
# if is_train:
# mask_output_fwd = torch.empty(
# [num_split,
# codeops.div_up(num_activate_out, mask_width)],
# dtype=torch.int32,
# device=features.device)
# # pytorch don't support uint32.
# mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
# dtype=tv.uint32)
# mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
# else:
# mask_output_fwd = None
# mask_output_fwd_tv = tv.Tensor()
# mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]
# # CONV.stream_synchronize(stream)
# # print("FPREPARE", time.time() - t)
# # # t = time.time()
# # CONV.stream_synchronize(stream)
# # t = time.time()
# # print(tune_res.algo_desp)
# with tv.measure_and_print(f"i8 time {features.shape[0]}-{in_channel}-{out_channel}"):
# with timer.record("implicit_gemm_i8", stream):
# for j in range(num_split):
# beta = 0 if j == 0 else 1
# CONV.run_with_tuned_result(
# tune_res,
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# mask=pair_mask_fwd_split_tvs[j],
# mask_argsort=mask_argsort_fwd_split_tvs[j],
# mask_output=mask_output_fwd_tvs[j],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks_ints[j],
# mask_width=-1,
# beta=beta,
# stream=stream,
# verbose=False)
# torch.cuda.synchronize()
# if DEBUG:
# CONV.stream_synchronize(stream)
# dura = time.time() - t
# print("F", tune_res.algo_desp, dura)
# print(out_features.mean(), out_features.max(), out_features.min())
return out_features, mask_output_fwd, mask_width return out_features, mask_output_fwd, mask_width
...@@ -1722,7 +1655,6 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1722,7 +1655,6 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: Optional[torch.Tensor], mask_output_fwd: Optional[torch.Tensor],
masks: List[np.ndarray], masks: List[np.ndarray],
mask_int_count: int,
mask_width: int, mask_width: int,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
...@@ -1782,7 +1714,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1782,7 +1714,7 @@ def implicit_gemm_backward(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv, mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, mask_int_count, arch, mask_width, is_subm, stream, mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32) use_tf32=constants.SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
...@@ -1802,6 +1734,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1802,6 +1734,8 @@ def implicit_gemm_backward(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1]) filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1] kv = filters.shape[1]
need_dynamic_mask = kv > 32
mask_int_count = div_up(kv, 32)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_bwd_tv = torch_tensor_to_tv(pair_bwd) pair_bwd_tv = torch_tensor_to_tv(pair_bwd)
...@@ -1831,11 +1765,12 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1831,11 +1765,12 @@ def implicit_gemm_backward(features: torch.Tensor,
dgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardInput, dgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardInput,
din_tv.dtype, filters_tv.dtype, din_tv.dtype, filters_tv.dtype,
dout_tv.dtype, out_channel, dout_tv.dtype, out_channel,
in_channel, arch) in_channel, arch, need_dynamic_mask=need_dynamic_mask)
wgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardWeight, wgrad_tune_res = CONV.get_tuned_algo(ConvOpType.kBackwardWeight,
features_tv.dtype, dfilters_tv.dtype, features_tv.dtype, dfilters_tv.dtype,
dout_tv.dtype, out_channel, dout_tv.dtype, out_channel,
in_channel, arch, mask_width) in_channel, arch, mask_width,
need_dynamic_mask=need_dynamic_mask)
if dgrad_tune_res is None: if dgrad_tune_res is None:
# TODO split mask maybe completely invalid # TODO split mask maybe completely invalid
...@@ -1861,8 +1796,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1861,8 +1796,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream, stream=stream,
fp32_accum=fp32_accum, fp32_accum=fp32_accum,
use_tf32=constants.SPCONV_ALLOW_TF32, use_tf32=constants.SPCONV_ALLOW_TF32)
mask_int_count=mask_int_count)
if wgrad_tune_res is None: if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache( wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1881,8 +1815,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1881,8 +1815,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_output=tv.Tensor(), mask_output=tv.Tensor(),
mask_width=mask_width, mask_width=mask_width,
stream=stream, stream=stream,
use_tf32=constants.SPCONV_ALLOW_TF32, use_tf32=constants.SPCONV_ALLOW_TF32)
mask_int_count=mask_int_count)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp, workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk, wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1919,8 +1852,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1919,8 +1852,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[j].item(), mask_filter=masks[j].item(),
mask_width=-1, mask_width=-1,
beta=beta, beta=beta,
stream=stream, stream=stream)
mask_int_count=mask_int_count)
# for backward weight, beta = 0 because each split # for backward weight, beta = 0 because each split
# handle different kernel locations. # handle different kernel locations.
# TODO remove D iterator in backward weight kernel # TODO remove D iterator in backward weight kernel
...@@ -1939,8 +1871,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1939,8 +1871,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width=mask_width, mask_width=mask_width,
beta=0, beta=0,
workspace=workspace_tv, workspace=workspace_tv,
stream=stream, stream=stream)
mask_int_count=mask_int_count)
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment