import torch
import tilelang
import tilelang.language as T

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
}


def convert_to_uint16(x):
    # Order-preserving float16 -> uint16 key: flip all bits for negatives,
    # set the sign bit for non-negatives, then keep the top 8 bits as the bucket id.
    hval = T.Cast("float16", x)
    bits_uint = T.reinterpret("uint16", hval)
    bits_uint = T.if_then_else(x < 0, ~bits_uint & 0xFFFF, bits_uint | 0x8000)
    return bits_uint >> 8


def convert_to_uint32(x):
    # Order-preserving float32 -> uint32 key, used by the refinement passes.
    bits_uint = T.reinterpret("uint32", x)
    bits_uint = T.if_then_else(
        x < 0,
        ~bits_uint & T.Cast("uint32", 0xFFFFFFFF),
        bits_uint | T.Cast("uint32", 0x80000000),
    )
    return bits_uint


@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
    batch = T.dynamic("batch")
    seq_len = T.dynamic("seq_len")
    RADIX = 1 << 8
    BLOCK_SIZE = 1024
    SMEM_INPUT_SIZE = 4096  # assumes the threshold bucket holds fewer than 4K candidates after the first pass

    @T.prim_func
    def tl_topk_kernel(
            input: T.Tensor[(batch, seq_len), in_dtype],
            index: T.Tensor[(batch, topk), out_dtype],
            starts: T.Tensor[(batch), out_dtype],
            ends: T.Tensor[(batch), out_dtype],
    ):
        with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
            tx = T.get_thread_binding()

            s_threshold_bin_id = T.alloc_shared([1], "int32")
            s_histogram = T.alloc_shared([RADIX + 1], "int32")
            s_num_input = T.alloc_shared([2], "int32")
            s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")

            l_threshold_bin_id = T.alloc_var("int32")
            l_new_topk = T.alloc_var("int32")
            l_num_input = T.alloc_var("int32")
            l_bin_id32 = T.alloc_var("int32")
            l_val = T.alloc_var("int32")
            l_start_pos = T.alloc_var("int32")
            l_start_idx = T.alloc_var("int32")
            l_end_idx = T.alloc_var("int32")
            l_out_pos = T.alloc_var("int32")

            l_new_topk = topk
            l_start_idx = starts[bx]
            l_end_idx = ends[bx]

            # stage 1: quick top-k selection over 8-bit keys
            T.fill(s_histogram, 0)
            T.fill(s_num_input[0], 0)
            T.sync_threads()

            # build a 256-bin histogram over the keys of the in-range elements
            for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
                input_idx = s * BLOCK_SIZE + tx
                if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
                    inval_int16 = convert_to_uint16(input[bx, input_idx])
                    T.atomic_add(s_histogram[inval_int16], 1)
            T.sync_threads()

            # suffix sum: s_histogram[b] becomes the count of elements with bucket >= b
            if tx < RADIX:
                for i in T.serial(8):
                    offset = 1 << i
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        l_val = s_histogram[tx] + s_histogram[tx + offset]
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        s_histogram[tx] = l_val
                # find the threshold bucket: the one that crosses the top-k boundary
                T.sync_threads(3, RADIX)
                if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
                    s_threshold_bin_id[0] = tx
            T.sync_threads()

            l_threshold_bin_id = s_threshold_bin_id[0]
            l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
            T.sync_threads()

            # emit all elements above the threshold bucket; buffer ties for refinement
            for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
                T.sync_threads()
                input_idx = s * BLOCK_SIZE + tx
                if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
                    bin_id = convert_to_uint16(input[bx, input_idx])
                    l_bin_id32 = T.Cast("int32", bin_id)
                    if l_bin_id32 > l_threshold_bin_id:
                        # definitely in the top-k: write the index out directly
                        pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
                        index[bx, pos] = input_idx
                    elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                        # tied with the threshold bucket: defer to stage 2
                        pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
                        s_input_idx[0, pos] = input_idx

            # stage 2: refine the tied candidates with up to four 8-bit passes over the
            # float32 key, ping-ponging indices between the two shared-memory buffers
            for round in T.serial(4):
                if l_new_topk <= 0:
                    T.loop_break()
                r_idx = round % 2
                l_start_pos = topk - l_new_topk
                T.sync_threads()
                T.fill(s_histogram, 0)
                if tx == 0:
                    s_num_input[r_idx ^ 1] = 0
                T.sync_threads()
                l_num_input = s_num_input[r_idx]

                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast(
                            "int32",
                            (convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
                             >> (24 - round * 8)) & 0xFF)
                        T.atomic_add(s_histogram[l_bin_id32], 1)
                T.sync_threads()

                # suffix sum, as in stage 1
                if tx < RADIX:
                    for i in T.serial(8):
                        offset = 1 << i
                        T.sync_threads(3, RADIX)
                        if tx < RADIX - offset:
                            l_val = s_histogram[tx] + s_histogram[tx + offset]
                        T.sync_threads(3, RADIX)
                        if tx < RADIX - offset:
                            s_histogram[tx] = l_val
                    # find threshold bin id
                    T.sync_threads(3, RADIX)
                    if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
                        s_threshold_bin_id[0] = tx
                T.sync_threads()
                l_threshold_bin_id = s_threshold_bin_id[0]
                l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
                T.sync_threads()

                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    T.sync_threads()
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast(
                            "int32",
                            (convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
                             >> (24 - round * 8)) & 0xFF)
                        if l_bin_id32 > l_threshold_bin_id:
                            # above the threshold bucket: write the index out
                            pos = T.atomic_add(
                                s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                            index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                        elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                            if round == 3:
                                # last pass: fill the remaining slots from the tied bucket
                                l_out_pos = T.atomic_add(
                                    s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                                if l_out_pos < topk:
                                    index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                            else:
                                # still tied: carry over to the other buffer for the next pass
                                pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
                                s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]

    return tl_topk_kernel


def tl_topk(input, starts, ends, topk):
    batch, seq_len = input.shape
    indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
    kernel = tl_topk_impl(topk)
    kernel(input, indexes, starts, ends)
    return indexes


def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
    torch.manual_seed(1)
    input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
    starts = torch.zeros(batch, dtype=torch.int32).cuda()
    ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len

    indexes = tl_topk(input, starts, ends, topk)
    print(indexes)

    indexes_ref = torch.topk(input, topk, dim=-1)[1]
    print(indexes_ref)

    # Compare against torch.topk: report the per-row overlap of selected indices
    for i in range(batch):
        ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
        trt_np = indexes[i].cpu().to(torch.int32).numpy()
        set_ref = set(ref_np)
        set_trt = set(trt_np)
        intersection = set_ref & set_trt
        print("selected/all:", len(intersection), "/", len(set_ref), "=",
              len(intersection) / len(set_ref))

    # Performance test with CUDA events
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        _ = tl_topk(input, starts, ends, topk)
    torch.cuda.synchronize()

    n_iters = 20
    start_event.record()
    for _ in range(n_iters):
        _ = tl_topk(input, starts, ends, topk)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")

    # torch.topk baseline
    start_event.record()
    for _ in range(n_iters):
        _ = torch.topk(input, topk, dim=-1)[1]
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")


if __name__ == "__main__":
    test_topk_selector()
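

# ------------------------------------------------------------------------------------
# Illustrative extra (not part of the original example): the radix select above relies
# on the standard order-preserving float -> unsigned-int key transform (flip all bits
# for negatives, set the sign bit otherwise), mirrored device-side by convert_to_uint16
# and convert_to_uint32. The hypothetical helper below sketches a host-side check of
# that property in plain PyTorch; it is not called anywhere in this script.
def _check_key_order_preserving(n=4096):
    x = torch.randn(n)
    bits = x.view(torch.int32).to(torch.int64) & 0xFFFFFFFF  # raw IEEE-754 bits as uint32
    # Same mapping as convert_to_uint32: negatives flipped, sign bit set for the rest.
    keys = torch.where(x < 0, ~bits & 0xFFFFFFFF, bits | 0x80000000)
    # Sorting by the integer keys must reproduce the float ordering.
    assert torch.equal(x[torch.argsort(keys)], torch.sort(x).values)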