Commit 3610ebfa authored by one
Browse files

Add launch bounds helper for sparse index kernels

Centralize the 1024-thread launch bound annotation for sparse index
CUDA kernels and apply it consistently across index, hash table, mask,
and SubM indice helper kernels. This keeps generated kernel definitions
aligned with the launch configuration used by DTK runtime checks.
parent a2dd956c
......@@ -23,6 +23,14 @@ from typing import List
from cumm.conv.params import ConvProblem
import numpy as np
# Launch-bounds attribute (1024 threads) shared by the sparse index CUDA
# kernels; _launch_bound_kernel_code() attaches it to generated kernel code.
SPARSE_INDICES_LAUNCH_BOUNDS = "__launch_bounds__(1024)"
def _launch_bound_kernel_code():
    """Create a ``pccm.FunctionCode`` pre-tagged with the shared launch bounds.

    Returns:
        A new ``pccm.FunctionCode`` on which ``add_pre_attr`` has already
        been called with ``SPARSE_INDICES_LAUNCH_BOUNDS``.
    """
    kernel_code = pccm.FunctionCode()
    # Attach the attribute up front so every kernel built from this object
    # carries the same launch-bounds annotation.
    kernel_code.add_pre_attr(SPARSE_INDICES_LAUNCH_BOUNDS)
    return kernel_code
class CudaCommonKernel(pccm.ParameterizedClass):
# we need to use PClass instead of Class
......@@ -34,8 +42,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def arange_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code = _launch_bound_kernel_code()
code.targ("T")
code.arg("data", f"T*")
code.arg("size", f"int")
......@@ -48,8 +55,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def fill_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code = _launch_bound_kernel_code()
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -63,8 +69,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def maximum_value_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code = _launch_bound_kernel_code()
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -293,7 +298,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TIndiceUniq")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......@@ -338,7 +343,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def build_conv_hash_table(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TLayoutNPQ")
......@@ -362,7 +367,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def arange_hash_table_and_assign_out(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TLayoutNPQ")
code.arg("table", f"TTable") # [N, ndim + 1]
......@@ -393,7 +398,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def arange_hash_table(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("out_indices_offset",
......@@ -419,7 +424,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def assign_out_indices(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("T")
code.targ("TLayoutNPQ")
code.arg("indices_out", f"int*") # [N, ndim + 1]
......@@ -435,7 +440,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_uniq_before_sort",
......@@ -466,7 +471,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
here we will use indice_pairs_uniq as temp memory of
indice_pairs_in_part.
"""
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1]
......@@ -503,7 +508,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TIndiceUniq")
code.targ("TConvLocIter")
......@@ -550,7 +555,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask_direct_table(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TIndiceUniq")
code.targ("TTable")
code.targ("TConvLocIter")
......@@ -601,7 +606,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool")
......@@ -654,7 +659,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask_output(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("mask_bwd", f"uint32_t*") # [kernelProd]
......@@ -680,7 +685,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_inference_mask(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool")
......@@ -725,8 +730,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TLayoutNPQ")
......@@ -746,7 +750,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def clean_indices_uniq(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("T")
code.arg("indice_pairs_for_uniq", f"T*")
code.arg("size", f"size_t")
......@@ -759,7 +763,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_subm_conv_indices(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......@@ -809,8 +813,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......@@ -880,7 +883,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_subm_conv_indices_split_mask(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("TTable")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......@@ -1605,7 +1608,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def init_subm_multiple_mask_int_kernel(self):
code = pccm.FunctionCode()
code = _launch_bound_kernel_code()
code.targ("T")
code.arg("ptr", "T*")
code.arg("set_bit", "int")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment