Commit 3610ebfa authored by one
Browse files

Add launch bounds helper for sparse index kernels

Centralize the 1024-thread launch bound annotation for sparse index
CUDA kernels and apply it consistently across index, hash table, mask,
and SubM indice helper kernels. This keeps generated kernel definitions
aligned with the launch configuration used by DTK runtime checks.
parent a2dd956c
...@@ -23,6 +23,14 @@ from typing import List ...@@ -23,6 +23,14 @@ from typing import List
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
import numpy as np import numpy as np
# Launch-bound attribute shared by the sparse index CUDA kernels (1024 threads).
SPARSE_INDICES_LAUNCH_BOUNDS = "__launch_bounds__(1024)"


def _launch_bound_kernel_code():
    """Return a new ``pccm.FunctionCode`` that carries the shared launch bound.

    ``SPARSE_INDICES_LAUNCH_BOUNDS`` is registered as a pre-attribute on the
    freshly created object before it is returned, so kernel definitions only
    need to declare their template arguments, parameters and body afterwards.
    """
    kernel_code = pccm.FunctionCode()
    kernel_code.add_pre_attr(SPARSE_INDICES_LAUNCH_BOUNDS)
    return kernel_code
class CudaCommonKernel(pccm.ParameterizedClass): class CudaCommonKernel(pccm.ParameterizedClass):
# we need to use PClass instead of Class # we need to use PClass instead of Class
...@@ -34,8 +42,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -34,8 +42,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_kernel(self): def arange_kernel(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("size", f"int") code.arg("size", f"int")
...@@ -48,8 +55,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -48,8 +55,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def fill_kernel(self): def fill_kernel(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -63,8 +69,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -63,8 +69,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def maximum_value_kernel(self): def maximum_value_kernel(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -293,7 +298,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -293,7 +298,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1(self): def calc_conv_indices_stage1(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
...@@ -338,7 +343,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -338,7 +343,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_conv_hash_table(self): def build_conv_hash_table(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
...@@ -362,7 +367,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -362,7 +367,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_hash_table_and_assign_out(self): def arange_hash_table_and_assign_out(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
...@@ -393,7 +398,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -393,7 +398,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_hash_table(self): def arange_hash_table(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("out_indices_offset", code.arg("out_indices_offset",
...@@ -419,7 +424,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -419,7 +424,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def assign_out_indices(self): def assign_out_indices(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("T") code.targ("T")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
code.arg("indices_out", f"int*") # [N, ndim + 1] code.arg("indices_out", f"int*") # [N, ndim + 1]
...@@ -435,7 +440,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -435,7 +440,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage2(self): def calc_conv_indices_stage2(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_uniq_before_sort", code.arg("indice_pairs_uniq_before_sort",
...@@ -466,7 +471,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -466,7 +471,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
here we will use indice_pairs_uniq as temp memory of here we will use indice_pairs_uniq as temp memory of
indice_pairs_in_part. indice_pairs_in_part.
""" """
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
...@@ -503,7 +508,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -503,7 +508,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self): def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.targ("TConvLocIter") code.targ("TConvLocIter")
...@@ -550,7 +555,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -550,7 +555,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask_direct_table(self): def calc_conv_indices_stage1_mask_direct_table(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TIndiceUniq") code.targ("TIndiceUniq")
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
...@@ -601,7 +606,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -601,7 +606,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask(self): def calc_conv_indices_stage2_mask(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool") code.nontype_targ("CheckValueValid", "bool")
...@@ -654,7 +659,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -654,7 +659,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_mask_output(self): def calc_conv_indices_stage2_mask_output(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("mask_bwd", f"uint32_t*") # [kernelProd] code.arg("mask_bwd", f"uint32_t*") # [kernelProd]
...@@ -680,7 +685,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -680,7 +685,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage2_inference_mask(self): def calc_conv_indices_stage2_inference_mask(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.nontype_targ("CheckValueValid", "bool") code.nontype_targ("CheckValueValid", "bool")
...@@ -725,8 +730,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -725,8 +730,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
...@@ -746,7 +750,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -746,7 +750,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def clean_indices_uniq(self): def clean_indices_uniq(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("T") code.targ("T")
code.arg("indice_pairs_for_uniq", f"T*") code.arg("indice_pairs_for_uniq", f"T*")
code.arg("size", f"size_t") code.arg("size", f"size_t")
...@@ -759,7 +763,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -759,7 +763,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_subm_conv_indices(self): def calc_subm_conv_indices(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
...@@ -809,8 +813,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -809,8 +813,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self): def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
...@@ -880,7 +883,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -880,7 +883,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_subm_conv_indices_split_mask(self): def calc_subm_conv_indices_split_mask(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
...@@ -1605,7 +1608,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -1605,7 +1608,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def init_subm_multiple_mask_int_kernel(self): def init_subm_multiple_mask_int_kernel(self):
code = pccm.FunctionCode() code = _launch_bound_kernel_code()
code.targ("T") code.targ("T")
code.arg("ptr", "T*") code.arg("ptr", "T*")
code.arg("set_bit", "int") code.arg("set_bit", "int")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment