From 29051439dbed90583bfad1d16dfca88a95e78709 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:22:18 +0800 Subject: [PATCH] [Lint] Phaseout Yapf format and embrace ruff format (#1417) --- .pre-commit-config.yaml | 14 +- .../benchmark_library_dense_fmha.py | 13 +- .../benchmark_tilelang_block_sparse_fmha.py | 72 +- .../benchmark_torch_block_sparse_fmha.py | 25 +- .../benchmark_triton_block_sparse_fmha.py | 40 +- .../mamba2/benchmark_mamba_chunk_scan.py | 207 +++-- benchmark/matmul/benchmark_matmul.py | 16 +- .../matmul/benchmark_matmul_intrinsic.py | 40 +- benchmark/matmul/benchmark_matmul_sp.py | 52 +- benchmark/matmul_fp8/benchmark_matmul.py | 15 +- docs/conf.py | 31 +- examples/amd/example_amd_flash_attn_bwd.py | 206 +++-- examples/amd/example_amd_flash_attn_fwd.py | 105 +-- examples/analyze/example_conv_analyze.py | 45 +- examples/analyze/example_gemm_analyze.py | 6 +- .../attention_sink/benchmark_gqa_sink_fwd.py | 48 +- .../attention_sink/benchmark_mha_sink_fwd.py | 64 +- .../example_gqa_sink_bwd_bhsd.py | 286 +++---- ...ample_gqa_sink_fwd_bhsd_wgmma_pipelined.py | 167 ++-- .../example_mha_sink_bwd_bhsd.py | 264 +++--- .../example_mha_sink_fwd_bhsd.py | 199 ++--- ...ample_mha_sink_fwd_bhsd_wgmma_pipelined.py | 210 +++-- examples/bitnet-1.58b/benchmark_generate.py | 35 +- .../benchmark_inference_latency.py | 9 +- examples/bitnet-1.58b/configuration_bitnet.py | 16 +- examples/bitnet-1.58b/eval_correctness.py | 22 +- examples/bitnet-1.58b/eval_gpu_memory.py | 13 +- examples/bitnet-1.58b/eval_ppl.py | 28 +- examples/bitnet-1.58b/eval_utils.py | 20 +- .../tilelang_bitnet_158_int8xint2_decode.py | 36 +- .../tilelang_bitnet_158_int8xint2_prefill.py | 77 +- .../kernel_benchmark/tl_int8xint8.py | 22 +- examples/bitnet-1.58b/load_from_quantized.py | 8 +- .../bitnet-1.58b/maint/create_bitblas_ckpt.py | 21 +- examples/bitnet-1.58b/modeling_bitnet.py | 308 +++---- examples/bitnet-1.58b/tokenization_bitnet.py | 60 +- examples/bitnet-1.58b/utils_quant.py | 24 +- .../bitnet-1.58b/vllm_workspace/conftest.py | 35 +- .../inference_with_compress_format.py | 15 +- .../inference_with_native_format.py | 14 +- examples/bitnet-1.58b/vllm_workspace/utils.py | 23 +- .../block_sparse_attn_triton.py | 71 +- .../example_tilelang_block_sparse_attn.py | 81 +- ...xample_tilelang_sparse_gqa_decode_paged.py | 240 +++--- ...ilelang_sparse_gqa_decode_varlen_indice.py | 240 +++--- ..._tilelang_sparse_gqa_decode_varlen_mask.py | 232 ++--- ..._triton_sparse_gqa_decode_varlen_indice.py | 153 ++-- ...le_triton_sparse_gqa_decode_varlen_mask.py | 147 ++-- examples/blocksparse_attention/heuristic.py | 3 +- .../test_example_blocksparse_attention.py | 20 +- .../example_blocksparse_gemm.py | 69 +- ...ample_group_per_split_token_cast_to_fp8.py | 51 +- .../cast/example_per_token_cast_to_fp8.py | 27 +- examples/cast/example_triton_cast_to_fp8.py | 4 +- examples/cast/test_example_cast.py | 3 +- examples/compile_flags/usecase.py | 10 +- examples/conftest.py | 7 +- examples/convolution/example_convolution.py | 64 +- .../example_convolution_autotune.py | 135 +-- .../example_deepgemm_fp8_2xAcc.py | 29 +- .../amd/benchmark_mla_decode_amd_tilelang.py | 186 ++-- .../amd/benchmark_mla_decode_amd_torch.py | 165 ++-- .../amd/benchmark_mla_decode_amd_triton.py | 165 ++-- examples/deepseek_mla/benchmark_mla.py | 198 ++--- examples/deepseek_mla/example_mla_decode.py | 181 ++-- .../deepseek_mla/example_mla_decode_paged.py | 226 +++-- .../example_mla_decode_persistent.py | 98 +-- .../deepseek_mla/example_mla_decode_ws.py | 245 +++--- .../experimental/example_mla_decode_kv_fp8.py | 85 +- examples/deepseek_mla/torch_refs.py | 29 +- .../benchmark/benchmark_nsa_fwd.py | 418 ++++----- .../deepseek_nsa/example_tilelang_nsa_bwd.py | 208 ++--- .../example_tilelang_nsa_decode.py | 50 +- .../deepseek_nsa/example_tilelang_nsa_fwd.py | 63 +- .../example_tilelang_nsa_fwd_varlen.py | 178 ++-- .../deepseek_nsa/example_triton_nsa_bwd.py | 354 +++++--- .../deepseek_nsa/example_triton_nsa_fwd.py | 124 +-- .../example_triton_nsa_fwd_varlen.py | 159 ++-- examples/deepseek_nsa/reference.py | 113 ++- examples/deepseek_v32/fp8_lighting_indexer.py | 96 +-- examples/deepseek_v32/sparse_mla_bwd.py | 199 ++--- examples/deepseek_v32/sparse_mla_fwd.py | 114 +-- .../deepseek_v32/sparse_mla_fwd_pipelined.py | 198 ++--- .../test_tilelang_example_deepseek_v32.py | 9 +- examples/deepseek_v32/topk_selector.py | 25 +- examples/deepseek_v32/utils.py | 144 ++-- examples/dequantize_gemm/dequantize_utils.py | 24 +- .../example_dequant_gemm_bf16_fp4_hopper.py | 219 ++--- .../example_dequant_gemm_bf16_mxfp4_hopper.py | 248 +++--- ...mple_dequant_gemm_bf16_mxfp4_hopper_tma.py | 250 +++--- .../example_dequant_gemm_fine_grained.py | 107 ++- .../example_dequant_gemm_fp4_hopper.py | 96 +-- .../example_dequant_gemm_w4a8.py | 40 +- .../example_dequant_gemv_fp16xint4.py | 66 +- ...e_dequant_groupedgemm_bf16_mxfp4_hopper.py | 190 ++--- examples/dsa_sparse_finetune/dsa.py | 95 +-- examples/dsa_sparse_finetune/index.py | 23 +- examples/dsa_sparse_finetune/indexer_bwd.py | 75 +- .../indexer_topk_reducesum.py | 56 +- .../dsa_sparse_finetune/sparse_mla_bwd.py | 228 ++--- .../dsa_sparse_finetune/sparse_mla_fwd.py | 126 ++- .../sparse_mla_topk_reducesum.py | 85 +- examples/dsa_sparse_finetune/utils.py | 6 +- examples/dynamic_shape/example_dynamic.py | 16 +- .../elementwise/example_elementwise_add.py | 16 +- examples/flash_attention/bert_padding.py | 16 +- examples/flash_attention/example_gqa_bwd.py | 322 +++---- .../example_gqa_bwd_tma_reduce.py | 357 ++++---- .../example_gqa_bwd_tma_reduce_varlen.py | 460 +++++----- .../example_gqa_bwd_wgmma_pipelined.py | 217 ++--- .../flash_attention/example_gqa_fwd_bshd.py | 133 ++- .../example_gqa_fwd_bshd_wgmma_pipelined.py | 106 +-- .../flash_attention/example_gqa_fwd_varlen.py | 150 ++-- .../flash_attention/example_mha_bwd_bhsd.py | 157 ++-- .../flash_attention/example_mha_bwd_bshd.py | 151 ++-- .../example_mha_bwd_bshd_wgmma_pipelined.py | 157 ++-- .../flash_attention/example_mha_fwd_bhsd.py | 98 +-- .../example_mha_fwd_bhsd_wgmma_pipelined.py | 109 +-- .../flash_attention/example_mha_fwd_bshd.py | 96 +-- .../example_mha_fwd_bshd_wgmma_pipelined.py | 107 +-- .../flash_attention/example_mha_fwd_varlen.py | 82 +- .../test_example_flash_attention.py | 6 +- examples/flash_attention/varlen_utils.py | 32 +- examples/flash_decoding/example_gqa_decode.py | 242 +++--- .../example_gqa_decode_varlen_logits.py | 297 +++---- .../example_gqa_decode_varlen_logits_paged.py | 292 +++---- .../flash_decoding/example_mha_inference.py | 145 ++-- .../fusedmoe/example_fusedmoe_tilelang.py | 349 ++++---- examples/fusedmoe/example_fusedmoe_torch.py | 91 +- examples/fusedmoe/test_example_fusedmoe.py | 9 +- examples/gdn/example_chunk_delta_bwd.py | 222 +++-- examples/gdn/example_chunk_delta_h.py | 138 +-- examples/gdn/example_chunk_o.py | 83 +- examples/gdn/example_chunk_o_bwd.py | 195 ++--- examples/gdn/example_chunk_scaled_dot_kkt.py | 46 +- examples/gdn/example_cumsum.py | 36 +- examples/gdn/example_wy_fast.py | 73 +- examples/gdn/example_wy_fast_bwd_split.py | 225 ++--- examples/gdn/test_example_gdn_compilation.py | 279 ++++-- examples/gdn/test_utils.py | 14 +- examples/gemm/example_gemm.py | 7 +- examples/gemm/example_gemm_autotune.py | 89 +- examples/gemm/example_gemm_intrinsics.py | 22 +- examples/gemm/example_gemm_persistent.py | 63 +- examples/gemm/example_gemm_schedule.py | 7 +- .../gemm_fp8/example_tilelang_gemm_amd.py | 77 +- .../gemm_fp8/example_tilelang_gemm_fp8.py | 15 +- .../example_tilelang_gemm_fp8_2xAcc.py | 16 +- .../example_tilelang_gemm_fp8_intrinsic.py | 22 +- .../example_tilelang_gemm_fp8_sm100.py | 10 +- examples/gemm_sm100/gemm_mma.py | 10 +- examples/gemm_sm100/gemm_tcgen5mma.py | 24 +- examples/gemm_sp/example_custom_compress.py | 188 ++-- examples/gemm_sp/example_gemm_sp.py | 123 ++- .../example_tilelang_gemm_splitk.py | 21 +- ...ilelang_gemm_splitk_vectorize_atomicadd.py | 21 +- .../example_tilelang_gemm_streamk.py | 10 +- examples/gemv/example_gemv.py | 81 +- .../grouped_gemm/example_grouped_gemm_bwd.py | 151 +--- .../grouped_gemm/example_grouped_gemm_fwd.py | 88 +- .../hadamard_transform/example_hadamard.py | 35 +- examples/lazy_jit/lazyjit.en.ipynb | 72 +- examples/lazy_jit/lazyjit.zh.ipynb | 72 +- .../example_linear_attn_bwd.py | 129 ++- .../example_linear_attn_fwd.py | 86 +- .../example_mamba_chunk_scan.py | 189 +++-- .../example_mamba_chunk_state.py | 119 ++- .../linear_attention/example_retention_fwd.py | 49 +- .../example_vertical_slash_sparse_attn.py | 223 ++--- examples/norm/rms_norm.py | 4 +- examples/norm/test_rms_norm.py | 4 +- examples/online_softmax/online_softmax.py | 12 +- examples/plot_layout/fragment_mfma_load_a.py | 20 +- examples/plot_layout/fragment_mma_load_a.py | 11 +- examples/quickstart.py | 7 +- .../block_sparse_attn_tilelang.py | 112 ++- .../block_sparse_attn_triton.py | 70 +- .../tilelang_example_sparse_tensorcore.py | 30 +- examples/topk/example_topk.py | 14 +- .../visual_layout_inference.py | 16 +- .../example_warp_specialize_flashmla.py | 144 ++-- ...warp_specialize_gemm_barrierpipe_stage2.py | 17 +- ...mple_warp_specialize_gemm_copy_0_gemm_1.py | 16 +- ...mple_warp_specialize_gemm_copy_1_gemm_0.py | 16 +- ...mple_warp_specialize_gemm_copy_gemm_0_1.py | 22 +- ...le_warp_specialize_gemm_softpipe_stage2.py | 1 - format.sh | 2 +- maint/gemm_v2/correctness_evaluation.py | 100 ++- maint/gemm_v2/correctness_evaluation_sm70.py | 25 +- .../gemm_v2/correctness_evaluation_tcgen05.py | 29 +- maint/gemm_v2/latency.py | 7 +- maint/gemm_v2/latency_gemm.py | 7 +- maint/gemm_v2/latency_mha_fwd_bhsd.py | 98 +-- maint/host_checks/01_num_args_mismatch.py | 1 + maint/host_checks/02_pointer_type_error.py | 1 + maint/host_checks/03_ndim_mismatch.py | 4 +- maint/host_checks/04_dtype_mismatch.py | 4 +- maint/host_checks/05_shape_mismatch.py | 4 +- maint/host_checks/06_strides_mismatch.py | 4 +- maint/host_checks/07_device_type_mismatch.py | 4 +- maint/host_checks/08_device_id_mismatch.py | 4 +- maint/host_checks/09_null_data_pointer.py | 1 + maint/host_checks/10_scalar_type_mismatch.py | 4 +- maint/host_checks/common.py | 17 +- maint/precision/compare_ops.py | 70 +- maint/scripts/ci_performance.py | 43 +- maint/scripts/performance.py | 40 +- pyproject.toml | 15 +- requirements-lint.txt | 1 - testing/conftest.py | 7 +- .../amd/test_tilelang_gemm_mfma_intrinsic.py | 54 +- .../amd/test_tilelang_gemm_mfma_preshuffle.py | 134 ++- testing/python/amd/test_tilelang_test_amd.py | 15 +- .../test_tilelang_fragment_loop_checker.py | 36 +- .../test_tilelang_nested_loop_checker.py | 120 ++- .../python/autotune/test_tilelang_autotune.py | 30 +- .../test_tilelang_autotune_with_inputs.py | 38 +- .../cache/test_tilelang_cache_matmul.py | 7 +- ..._tilelang_carver_cuda_driver_properties.py | 16 +- .../test_tilelang_carver_generate_hints.py | 25 +- .../test_tilelang_carver_recommend_hints.py | 17 +- .../test_storage_rewrite_detect_inplace.py | 3 +- ...ng_pass_config_disable_warp_specialized.py | 9 +- testing/python/cpu/test_tilelang_cpu_gemm.py | 14 +- testing/python/debug/test_device_assert.py | 2 - .../python/debug/test_tilelang_debug_print.py | 37 +- .../dynamic/test_tilelang_dynamic_symbolic.py | 81 +- .../test_tilelang_dynamic_symbolic_bench.py | 54 +- .../python/fastmath/test_mathops_fastmath.py | 81 +- .../python/issue/test_tilelang_issue_1001.py | 13 +- .../python/issue/test_tilelang_issue_1008.py | 22 +- .../python/issue/test_tilelang_issue_1115.py | 24 +- .../python/issue/test_tilelang_issue_1198.py | 14 +- .../python/issue/test_tilelang_issue_814.py | 5 +- .../python/issue/test_tilelang_issue_830.py | 2 - .../python/issue/test_tilelang_issue_96.py | 16 +- .../issue/test_tilelang_issue_merge_if.py | 1 - .../python/jit/test_tilelang_jit_callback.py | 13 +- testing/python/jit/test_tilelang_jit_gemm.py | 7 +- .../jit/test_tilelang_jit_gemm_cython.py | 157 ++-- .../python/jit/test_tilelang_jit_nullptr.py | 23 +- testing/python/jit/test_tilelang_jit_nvrtc.py | 173 +--- .../jit/test_tilelang_jit_parcompile.py | 12 +- .../python/jit/test_tilelang_jit_tvm_ffi.py | 181 ++-- .../test_tilelang_kernel_bf16_gemm_mma.py | 22 +- .../test_tilelang_kernel_element_wise_add.py | 8 +- .../kernel/test_tilelang_kernel_fp8_gemm.py | 10 +- .../test_tilelang_kernel_fp8_gemm_mma.py | 22 +- .../test_tilelang_kernel_fp8_gemv_simt.py | 36 +- .../kernel/test_tilelang_kernel_gemm.py | 22 +- ...test_tilelang_kernel_gemm_mma_intrinsic.py | 22 +- .../kernel/test_tilelang_kernel_gemm_simt.py | 27 +- .../test_tilelang_kernel_gemm_with_stride.py | 10 +- .../kernel/test_tilelang_kernel_gemv_simt.py | 36 +- .../test_tilelang_kernel_int4_gemm_mma.py | 51 +- .../python/language/test_tilelang_capture.py | 11 +- .../python/language/test_tilelang_intimm.py | 22 +- .../language/test_tilelang_language_alias.py | 7 +- .../language/test_tilelang_language_all_of.py | 42 +- .../language/test_tilelang_language_alloc.py | 16 +- .../language/test_tilelang_language_annot.py | 35 +- ...t_tilelang_language_annotate_safe_value.py | 14 +- .../language/test_tilelang_language_any_of.py | 42 +- .../language/test_tilelang_language_assume.py | 25 +- .../test_tilelang_language_atomic_add.py | 45 +- .../test_tilelang_language_ceildiv.py | 2 - .../test_tilelang_language_chain_equal.py | 10 +- .../language/test_tilelang_language_clamp.py | 8 +- .../language/test_tilelang_language_clear.py | 11 +- ...test_tilelang_language_composable_index.py | 8 +- .../language/test_tilelang_language_copy.py | 54 +- .../language/test_tilelang_language_cumsum.py | 31 +- .../test_tilelang_language_frontend_v2.py | 87 +- .../test_tilelang_language_get_warp_info.py | 5 - .../test_tilelang_language_if_range.py | 9 +- .../test_tilelang_language_infinity.py | 3 +- ...st_tilelang_language_intrinsics_codegen.py | 4 +- .../test_tilelang_language_lazy_jit.py | 131 ++- .../language/test_tilelang_language_let.py | 1 - .../test_tilelang_language_mask_op.py | 62 +- .../test_tilelang_language_negative_index.py | 3 +- .../test_tilelang_language_parallel.py | 12 +- .../test_tilelang_language_pipeline.py | 48 +- .../language/test_tilelang_language_ptr.py | 1 - .../language/test_tilelang_language_reduce.py | 20 +- .../test_tilelang_language_reshape.py | 54 +- .../test_tilelang_language_ternary.py | 12 +- .../language/test_tilelang_language_tma_1d.py | 6 +- .../language/test_tilelang_language_unroll.py | 2 - .../test_tilelang_language_var_init.py | 12 +- .../test_tilelang_language_vectorize.py | 16 +- .../test_tilelang_language_vectorized_cast.py | 11 +- .../language/test_tilelang_language_view.py | 12 +- .../test_tilelang_language_warp_reduce.py | 21 +- .../test_tilelang_layout_fused_replicate.py | 13 +- .../python/math/test_math_bitwise_reduce.py | 11 +- testing/python/math/test_math_fast_math.py | 81 +- testing/python/math/test_math_ieee_math.py | 59 +- testing/python/metal/test_metal_codegen.py | 29 +- .../test_tilelang_primitives_mma.py | 48 +- .../python/profiler/test_tilelang_profiler.py | 7 +- .../test_tilelang_tilelibrary_gemm.py | 62 +- .../test_tilelang_tilelibrary_gemm_sp.py | 91 +- .../test_tilelang_tilelibrary_gemm_sp_v2.py | 126 ++- ...lang_transform_Inject_software_pipeline.py | 12 +- ...est_tilelang_transform_cluster_planning.py | 7 +- ...ilelang_transform_config_index_bitwidth.py | 66 +- ...t_tilelang_transform_inject_fence_proxy.py | 56 +- ..._tilelang_transform_inject_set_max_nreg.py | 52 +- ...est_tilelang_transform_layout_inference.py | 94 +- ...g_transform_legalize_safe_memory_access.py | 54 +- ...lang_transform_legalize_vectorized_loop.py | 8 +- .../test_tilelang_transform_let_inline.py | 5 +- ..._tilelang_transform_lower_hopper_intrin.py | 17 +- .../test_tilelang_transform_lower_tile_op.py | 72 +- ...test_tilelang_transform_make_packed_api.py | 23 +- ...tilelang_transform_multi_version_buffer.py | 58 +- ...st_tilelang_transform_pipeline_planning.py | 22 +- .../test_tilelang_transform_simplify.py | 13 +- .../test_tilelang_transform_thread_sync.py | 30 +- ...est_tilelang_transform_warp_specialized.py | 72 +- testing/python/utils/test_compress_utils.py | 2 +- testing/python/webgpu/test_webgpu_codegen.py | 7 +- tilelang/__init__.py | 2 + tilelang/analysis/fragment_loop_checker.py | 18 +- tilelang/analysis/layout_visual.py | 6 +- tilelang/analysis/nested_loop_checker.py | 23 +- tilelang/autotuner/capture.py | 3 +- tilelang/autotuner/param.py | 73 +- tilelang/autotuner/tuner.py | 171 ++-- tilelang/cache/__init__.py | 15 +- tilelang/cache/kernel_cache.py | 54 +- tilelang/carver/__init__.py | 1 + tilelang/carver/analysis.py | 22 +- tilelang/carver/arch/__init__.py | 28 +- tilelang/carver/arch/arch_base.py | 8 +- tilelang/carver/arch/cdna.py | 5 +- tilelang/carver/arch/cpu.py | 5 +- tilelang/carver/arch/cuda.py | 17 +- tilelang/carver/arch/driver/cuda_driver.py | 3 +- tilelang/carver/arch/metal.py | 5 +- tilelang/carver/common_schedules.py | 1 + tilelang/carver/matmul_analysis.py | 79 +- tilelang/carver/roller/bestfit.py | 8 +- tilelang/carver/roller/hint.py | 5 +- tilelang/carver/roller/node.py | 65 +- tilelang/carver/roller/policy/default.py | 84 +- tilelang/carver/roller/policy/tensorcore.py | 55 +- tilelang/carver/roller/rasterization.py | 2 - .../carver/roller/shape_inference/common.py | 5 +- tilelang/carver/roller/shape_inference/tir.py | 41 +- tilelang/carver/template/base.py | 13 +- tilelang/carver/template/conv.py | 25 +- tilelang/carver/template/flashattention.py | 6 +- tilelang/carver/template/gemv.py | 9 +- tilelang/carver/template/general_reduce.py | 10 +- tilelang/carver/template/matmul.py | 9 +- tilelang/carver/utils.py | 27 +- tilelang/contrib/cc.py | 34 +- tilelang/contrib/dlpack.py | 9 +- tilelang/contrib/hipcc.py | 9 +- tilelang/contrib/nvcc.py | 37 +- tilelang/contrib/nvrtc.py | 21 +- tilelang/contrib/rocm.py | 7 +- tilelang/engine/lower.py | 39 +- tilelang/engine/param.py | 3 + tilelang/engine/phase.py | 18 +- tilelang/env.py | 72 +- tilelang/intrinsics/mfma_layout.py | 18 +- tilelang/intrinsics/mfma_macro_generator.py | 195 ++--- tilelang/intrinsics/mma_layout.py | 22 +- tilelang/intrinsics/mma_macro_generator.py | 143 ++-- tilelang/intrinsics/mma_sm70_layout.py | 8 +- .../intrinsics/mma_sm70_macro_generator.py | 75 +- tilelang/intrinsics/mma_sp_layout.py | 27 +- tilelang/intrinsics/mma_sp_macro_generator.py | 115 +-- .../intrinsics/tcgen05_macro_generator.py | 93 +- tilelang/intrinsics/utils.py | 2 +- tilelang/intrinsics/wgmma_macro_generator.py | 179 ++-- tilelang/ir.py | 48 +- tilelang/jit/__init__.py | 160 ++-- tilelang/jit/adapter/base.py | 10 +- tilelang/jit/adapter/ctypes/adapter.py | 63 +- tilelang/jit/adapter/cython/adapter.py | 79 +- tilelang/jit/adapter/libgen.py | 19 +- tilelang/jit/adapter/nvrtc/__init__.py | 14 +- tilelang/jit/adapter/nvrtc/adapter.py | 52 +- tilelang/jit/adapter/nvrtc/libgen.py | 20 +- tilelang/jit/adapter/nvrtc/wrapper.py | 146 ++-- tilelang/jit/adapter/torch/__init__.py | 2 +- tilelang/jit/adapter/torch/metal.py | 12 +- tilelang/jit/adapter/tvm_ffi.py | 80 +- tilelang/jit/adapter/utils.py | 112 +-- tilelang/jit/adapter/wrapper.py | 226 ++--- tilelang/jit/execution_backend.py | 7 +- tilelang/jit/kernel.py | 71 +- tilelang/language/__init__.py | 6 +- tilelang/language/allocate.py | 41 +- tilelang/language/annotations.py | 1 + tilelang/language/ast/__init__.py | 1 + tilelang/language/ast/_ffi_api.py | 1 + tilelang/language/ast/ir.py | 87 +- tilelang/language/atomic.py | 26 +- tilelang/language/builtin.py | 110 ++- tilelang/language/copy.py | 58 +- tilelang/language/customize.py | 12 +- tilelang/language/experimental/gemm_sp.py | 13 +- tilelang/language/fill.py | 7 +- tilelang/language/frame.py | 8 +- tilelang/language/gemm.py | 36 +- tilelang/language/kernel.py | 28 +- tilelang/language/logical.py | 7 +- tilelang/language/loop.py | 25 +- tilelang/language/math_intrinsics.py | 2 +- tilelang/language/overrides/parser.py | 25 +- tilelang/language/parser/entry.py | 8 +- tilelang/language/parser/operation.py | 12 +- tilelang/language/parser/parser.py | 12 +- tilelang/language/print.py | 40 +- tilelang/language/proxy.py | 92 +- tilelang/language/reduce.py | 11 +- tilelang/language/tir/entry.py | 7 +- tilelang/language/tir/ir.py | 22 +- tilelang/language/tir/ir.pyi | 110 ++- tilelang/language/tir/op.py | 28 +- tilelang/language/utils.py | 13 +- tilelang/language/v2/annot.py | 166 ++-- tilelang/language/v2/ast.py | 159 ++-- tilelang/language/v2/builder.py | 202 ++--- tilelang/language/v2/dtypes.py | 802 +++++++++--------- tilelang/language/v2/utils.py | 21 +- tilelang/language/warpgroup.py | 1 + tilelang/layout/fragment.py | 22 +- tilelang/layout/gemm_sp.py | 9 +- tilelang/layout/layout.py | 8 +- tilelang/layout/swizzle.py | 27 +- tilelang/libinfo.py | 3 +- tilelang/primitives/__init__.py | 2 +- tilelang/primitives/gemm/__init__.py | 12 +- tilelang/primitives/gemm/base.py | 19 +- tilelang/primitives/gemm/gemm_mma.py | 26 +- tilelang/profiler/__init__.py | 15 +- tilelang/profiler/bench.py | 10 +- tilelang/quantize/lop3.py | 7 +- tilelang/quantize/mxfp.py | 10 +- tilelang/quantize/utils.py | 9 +- tilelang/testing/__init__.py | 19 +- tilelang/tileop/gemm/__init__.py | 3 +- tilelang/tileop/gemm/gemm_mfma.py | 18 +- tilelang/tileop/gemm/gemm_mma.py | 17 +- tilelang/tileop/gemm/gemm_mma_sm70.py | 17 +- tilelang/tileop/gemm/gemm_tcgen05.py | 30 +- tilelang/tileop/gemm/gemm_wgmma.py | 37 +- tilelang/tileop/gemm_sp/__init__.py | 6 +- tilelang/tileop/gemm_sp/gemm_sp_mma.py | 14 +- tilelang/tools/Analyzer.py | 16 +- tilelang/tools/plot_layout.py | 80 +- tilelang/transform/__init__.py | 30 +- tilelang/transform/add_bufstore_wrapper.py | 9 +- tilelang/transform/pass_config.py | 1 + tilelang/transform/simplify.py | 1 - tilelang/utils/deprecated.py | 8 +- tilelang/utils/language.py | 19 +- tilelang/utils/sparse.py | 49 +- tilelang/utils/target.py | 5 +- tilelang/utils/tensor.py | 68 +- version_provider.py | 48 +- 467 files changed, 12931 insertions(+), 15919 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e5bab4..d1bb4ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,19 +39,9 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.7 # sync with requirements-lint.txt hooks: + - id: ruff-format - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/google/yapf - rev: v0.43.0 # sync with requirements-lint.txt - hooks: - - id: yapf - name: yapf-multiproc-bugfix - # yapf is not multiprocess safe, so we run a dummy yapf first. - args: [--in-place, docs/conf.py] - always_run: true - pass_filenames: false - - id: yapf - args: [--recursive, --in-place] - repo: https://github.com/codespell-project/codespell rev: v2.4.1 # sync with requirements-lint.txt hooks: @@ -62,4 +52,4 @@ repos: ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$| ^.+\.svg$| ^.*\brequirements\b.*\.txt$ - ) \ No newline at end of file + ) diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index 6401276..3dd82aa 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) import flash_attn diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index 7c9edb5..fff65b4 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k]: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) def benchmark_fn(): diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index e4828ce..85d754a 100644 --- a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) return ref_output ref_latency = do_bench( diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 86ac894..7ebca93 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -72,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -153,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -191,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -253,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py index aff810f..a3ed72b 100644 --- a/benchmark/mamba2/benchmark_mamba_chunk_scan.py +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): - @helion.kernel() def helion_mamba2_chunk_scan_kernel( cb: torch.Tensor, @@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): dtype = cb.dtype accum_dtype = torch.float32 - assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == - dtype) + assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype out = torch.empty_like(x) @@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( [nheads, chunk_size, headdim, batch, nchunks], - block_size=[1, block_m, block_n, 1, 1], + block_size=[1, block_m, block_n, 1, 1], ): acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) - dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_m].to(torch.float32) + dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32) scale_m_local = torch.exp2(dA_cumsum_local_m * p) C_local = C[ @@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): tile_m, tile_k, ] - dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_k].to(torch.float32) - cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - - dA_cumsum_local_k[None, :] * p) + dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p) dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) cb_local = (cb_local * dt_local[None, :]).to(dtype) pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] @@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): acc_o = hl.dot(cb_local, x_local, acc=acc_o) D_local = D[tile_h.begin].to(torch.float32) - x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, - tile_n].to(torch.float32) + x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32) acc_o += x_residual * D_local - out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, - tile_n] = acc_o.to(dtype=dtype) + out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype) return out @@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -198,19 +187,21 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) @@ -218,20 +209,20 @@ def chunk_scan_fwd(batch, @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") @@ -257,27 +248,32 @@ def chunk_scan_fwd(batch, m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -286,34 +282,47 @@ def chunk_scan_fwd(batch, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -321,24 +330,37 @@ def chunk_scan_fwd(batch, T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) nchunks = math.ceil(seq_len / chunk_size) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate @@ -360,8 +382,7 @@ if __name__ == "__main__": D = torch.randn(heads).half().cuda() print("Benchmarking Triton...") - triton_latency = do_bench( - lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) + triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") print("Benchmarking Helion...") diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index c64f4fa..6ca1402 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -6,6 +6,7 @@ import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit + # Configure logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -101,9 +102,7 @@ def get_configs(args, kwargs): policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -112,7 +111,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -159,9 +160,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -176,7 +177,6 @@ def matmul( # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 94e36b3..010ce87 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -6,7 +6,8 @@ import tilelang as tl import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.autotuner import autotune import itertools @@ -103,12 +104,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10, enable=enable_rasteration) @@ -127,7 +129,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -223,7 +223,6 @@ def get_configs(args, kwargs): for config in configs: print(config) else: - iter_params = dict( block_row_warps=[1, 2, 4], block_col_warps=[1, 2, 4], @@ -233,9 +232,7 @@ def get_configs(args, kwargs): stage=[0, 2], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -247,7 +244,9 @@ def get_configs(args, kwargs): ref_prog=ref_program, skip_check=True, ) -@tl.jit(out_idx=[2],) +@tl.jit( + out_idx=[2], +) def matmul( M, N, @@ -291,13 +290,8 @@ if __name__ == "__main__": parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--with_roller", - type=bool, - default=False, - help="Whether to use roller to deduce search spaces") - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") + parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") args = parser.parse_args() M, N, K = args.m, args.n, args.k diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 0ff3cd0..22b5d13 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -70,7 +70,8 @@ def get_configs(M, N, K): thread_num, policy, enable_rasterization, - )) + ) + ) configs = [ { @@ -81,7 +82,8 @@ def get_configs(M, N, K): "thread_num": c[4], "policy": c[5], "enable_rasterization": c[6], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): warmup=3, rep=20, ) - @jit(out_idx=[2],) + @jit( + out_idx=[2], + ) def kernel( block_M=None, block_N=None, @@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): @T.prim_func def main( - A_sparse: T.Tensor((M, K // 2), in_dtype), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), in_dtype), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): """ # Bind x-dimension to block index in N, # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), - E_shared: - make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), + } + ) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): # Load a sub-block of A from global memory into A_shared @@ -241,18 +243,13 @@ if __name__ == "__main__": parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--disable_cache", action="store_true") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument( "--bench_torch_sparse", type=str, - choices=['cutlass', 'cusparselt'], + choices=["cutlass", "cusparselt"], default=None, - help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported" + help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported", ) args = parser.parse_args() @@ -274,7 +271,8 @@ if __name__ == "__main__": if args.bench_torch_sparse is not None: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - if args.bench_torch_sparse == 'cutlass': + + if args.bench_torch_sparse == "cutlass": SparseSemiStructuredTensor._FORCE_CUTLASS = True A_sp = to_sparse_semi_structured(A, transposed=False) torch_sparse_latency = do_bench(lambda: A_sp @ B) @@ -285,8 +283,6 @@ if __name__ == "__main__": print(f"Best config: {best_config}") if args.bench_torch_sparse is not None: - print( - f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}" - ) + print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}") print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 796f7b9..930e8a6 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -104,9 +104,7 @@ def get_configs(args, kwargs): policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ -116,7 +114,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -164,9 +164,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -181,7 +181,6 @@ def matmul( # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/docs/conf.py b/docs/conf.py index 9d52415..877b558 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,33 +20,27 @@ extensions = [ "autoapi.extension", ] -autoapi_type = 'python' -autoapi_dirs = ['../tilelang'] +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] autoapi_options = [ - 'members', - 'undoc-members', - 'show-inheritance', - 'show-module-summary', - 'special-members', + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", ] autoapi_keep_files = False # Useful for debugging the generated rst files autoapi_generate_api_docs = True -autodoc_typehints = 'description' +autodoc_typehints = "description" autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} -myst_enable_extensions = [ - "colon_fence", - "deflist", -] +myst_enable_extensions = ["colon_fence", "deflist"] redirects = {"get_started/try_out": "../index.html#getting-started"} @@ -66,10 +60,7 @@ html_css_files = ["custom.css"] footer_copyright = "© 2025-2026 TileLang" footer_note = " " -html_theme_options = { - "light_logo": "img/logo-v2.png", - "dark_logo": "img/logo-v2.png", -} +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} header_links = [ ("Home", "https://github.com/tile-ai/tilelang"), diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index d5c52f9..a546110 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -11,22 +11,20 @@ import time def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K_ref = K.repeat_interleave(groups, dim=2) V_ref = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref) lse = torch.logsumexp(scores, dim=-1).float() return output, lse @@ -45,23 +43,23 @@ def get_fwd_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -85,7 +83,7 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -97,11 +95,11 @@ def fast_flashattn( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -135,33 +133,21 @@ def fast_flashattn( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = ( - T.ceildiv(q_block_offset + - block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -216,8 +202,7 @@ def fast_flashattn( for i in T.Parallel(block_M): if q_block_offset + i < seq_len: - lse_val = T.if_then_else(l_i[i] > 0, - T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) LSE[bz, by, q_block_offset + i] = lse_val bx_loop_var = current_bx + num_split_q @@ -234,16 +219,17 @@ def get_bwd_configs(): panel_size = [7, 8, 9, 10] configs = [] - for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, - enable_rasterization, panel_size): - configs.append({ - "block_M": m, - "block_N": n, - "num_stages": stages, - "threads": t, - "enable_rasterization": r, - "panel_size": p, - }) + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) return configs @@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): blk = 32 @T.prim_func - def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), - Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): o = T.alloc_fragment([blk, blk], dtype) do = T.alloc_fragment([blk, blk], dtype) @@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) @tilelang.jit -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int, - num_stages: int, threads: int, enable_rasterization: bool, panel_size: int): - sm_scale = (1.0 / dim)**0.5 +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b accum_dtype = "float" @T.prim_func - def flash_bwd_kernel(Q: T.Tensor(q_shape, - dtype), K: T.Tensor(kv_shape, - dtype), V: T.Tensor(kv_shape, dtype), - dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], - accum_dtype), - Delta: T.Tensor([batch, heads, seq_len], - accum_dtype), dQ: T.Tensor(q_shape, accum_dtype), - dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)): + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b dk = T.alloc_fragment([block_M, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) @@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) T.clear(qkT) T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, - P_acc[i, j], 0.0) + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) T.clear(dP) T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) for i, j in T.Parallel(block_M, block_N): p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale @@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.copy( - dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post @@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100): return np.median(times) -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): device = "cuda" dtype = torch.float16 torch.manual_seed(42) torch.cuda.manual_seed(42) - print( - f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}" - ) + print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 5 * flops_per_gemm @@ -517,22 +508,19 @@ def main(batch: int = 1, o_ref.backward(dO) print("Verifying backward pass correctness...") - dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( - dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) if dq_close: print("dQ is correct.") else: print("dQ mismatch detected.") - dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( - dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) if dk_close: print("dK is correct.") else: print("dK mismatch detected.") - dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( - dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) if dv_close: print("dV is correct.") else: @@ -553,9 +541,7 @@ def main(batch: int = 1, torch.cuda.synchronize() ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) - print( - f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops" - ) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") def run_complete_fwd_bwd(): o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) @@ -593,12 +579,12 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index 9ffa7cb..e53299a 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -13,10 +13,10 @@ def supply_tensors_gpu(params): """Supply function that creates tensors on GPU for ROCm/HIP.""" tensors = [] for param in params: - if hasattr(param, 'shape') and hasattr(param, 'dtype'): + if hasattr(param, "shape") and hasattr(param, "dtype"): # Force creation on GPU device shape = [int(s) for s in param.shape] - tensor = torch.randn(shape, dtype=param.dtype, device='cuda') + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") tensors.append(tensor) else: tensors.append(param) @@ -24,22 +24,20 @@ def supply_tensors_gpu(params): def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -58,23 +56,23 @@ def get_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -98,7 +96,7 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -110,10 +108,10 @@ def fast_flashattn( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -147,32 +145,21 @@ def fast_flashattn( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = T.ceildiv(q_block_offset + block_M, - block_N) if is_causal else T.ceildiv(seq_len, block_N) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -222,13 +209,7 @@ def fast_flashattn( return main -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: @@ -250,18 +231,16 @@ def main(batch: int = 1, print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") latency = profiler.do_bench(warmup=100) - print( - f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops" - ) + print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index 540fcf4..b90be14 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -25,22 +25,7 @@ def check_hopper(): return False -def kernel(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -50,13 +35,11 @@ def kernel(N, @T.prim_func def conv( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -65,11 +48,13 @@ def kernel(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: make_swizzled_layout(out_shared), - data_shared: make_swizzled_layout(data_shared), - kernel_shared: make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: make_swizzled_layout(out_shared), + data_shared: make_swizzled_layout(data_shared), + kernel_shared: make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -81,10 +66,8 @@ def kernel(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index bfd934f..e28440e 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -20,9 +20,9 @@ def kernel( @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 1b7de6b..3538adc 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -51,8 +51,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o @@ -137,12 +137,11 @@ def main( ): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -170,15 +169,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) if torch.allclose( - triton_program(Q, K, V, sinks, window_size), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2): + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): print("Checks for triton passed.✅") else: print("Checks for triton failed.❌") @@ -198,20 +196,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index f50b945..76997d8 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -50,8 +50,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -163,15 +164,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) @@ -184,19 +184,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index b442505..5af787a 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -20,28 +20,30 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - groups=1, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -51,12 +53,12 @@ def flashattn_fwd( @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - Output: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -73,7 +75,7 @@ def flashattn_fwd( sinks = T.alloc_fragment([heads], dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -81,22 +83,20 @@ def flashattn_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -106,8 +106,7 @@ def flashattn_fwd( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -124,22 +123,23 @@ def flashattn_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -147,9 +147,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -185,32 +186,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16" @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim, - groups, - window_size=None, - sm_scale=None, - dtype="float16"): # None for full attention +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"): # None for full attention if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -225,15 +221,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - dO: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(kv_shape, accum_dtype), # type: ignore - dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -253,44 +249,47 @@ def flashattn_bwd(batch, dv_shared = T.alloc_shared([block_M, dim], accum_dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -299,12 +298,12 @@ def flashattn_bwd(batch, T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) return flash_bwd @@ -316,10 +315,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16" @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): sink = T.alloc_local([1], dtype) @@ -328,21 +327,18 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16" dsink_fragment = T.alloc_fragment([block], dtype) sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size, groups): - def maybe_contiguous(x): if x.stride(-1) != 1: return x.contiguous() @@ -388,13 +384,14 @@ attention = _attention.apply # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -430,32 +427,31 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 8, - N_CTX: int = 512, - D_HEAD: int = 64, - groups: int = 2, - window_size: Optional[int] = None, - dtype: str = "float16"): +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) - K = torch.randn( - BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() V = torch.randn_like(K).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() dO = torch.randn_like(Q) @@ -479,16 +475,11 @@ def main(BATCH: int = 1, "float16": (1e-2, 1e-2), "bfloat16": (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -509,17 +500,12 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument('--groups', type=int, default=8, help='Groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 8d18172..feb5844 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -23,9 +23,11 @@ def get_configs(): rep=100, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -41,12 +43,11 @@ def flashattn( threads=256, dtype: str = "float16", ): - if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -68,13 +69,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -89,18 +89,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -112,8 +112,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -128,19 +127,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -157,58 +156,58 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - groups, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks @@ -277,12 +268,11 @@ def main( ): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -310,15 +300,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") # Benchmark tilelang @@ -329,20 +318,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index b9fa0fd..155c488 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -20,27 +20,29 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - window_size=None, # None for full attention, - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] @@ -48,12 +50,12 @@ def flashattn_fwd( @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -70,7 +72,7 @@ def flashattn_fwd( sinks = T.alloc_fragment([heads], dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -78,22 +80,20 @@ def flashattn_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -103,8 +103,7 @@ def flashattn_fwd( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -121,22 +120,23 @@ def flashattn_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -144,9 +144,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -155,26 +155,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -182,22 +183,24 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16" @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd( batch, heads, @@ -207,11 +210,10 @@ def flashattn_bwd( sm_scale=None, dtype: str = "float16", ): - block_M, block_N, num_stages, threads = get_bwd_configs() if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] @@ -222,15 +224,15 @@ def flashattn_bwd( @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -254,43 +256,46 @@ def flashattn_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -299,12 +304,12 @@ def flashattn_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd @@ -316,10 +321,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16" @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): sink = T.alloc_local([1], dtype) @@ -328,18 +333,16 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16" dsink_fragment = T.alloc_fragment([block], accum_dtype) sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size): BATCH, H, N_CTX, D_HEAD = q.shape @@ -383,15 +386,15 @@ attention = _attention.apply # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -426,29 +429,22 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 1, - N_CTX: int = 512, - D_HEAD: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16"): +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() K = torch.randn_like(Q).requires_grad_() V = torch.randn_like(Q).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() @@ -473,16 +469,11 @@ def main(BATCH: int = 1, "float16": (1e-2, 1e-2), "bfloat16": (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -503,16 +494,11 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 0ccb695..78ac443 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -18,27 +18,30 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] @@ -58,13 +61,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,18 +81,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,8 +104,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -118,19 +119,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -147,53 +148,51 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 1, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -289,19 +282,17 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") - latency = do_bench( - lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -311,19 +302,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index 64d6ec6..decdc8f 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -19,28 +19,30 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16"): - + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] @@ -61,13 +63,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -82,18 +83,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -105,8 +106,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -121,19 +121,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -150,60 +150,59 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function'sinterface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -299,15 +292,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -317,19 +309,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/bitnet-1.58b/benchmark_generate.py b/examples/bitnet-1.58b/benchmark_generate.py index d6f21ed..d678b91 100644 --- a/examples/bitnet-1.58b/benchmark_generate.py +++ b/examples/bitnet-1.58b/benchmark_generate.py @@ -12,8 +12,7 @@ bitblas.set_log_level("INFO") def generate_text_batch(model, tokenizer, prompts, max_length=100): # Encode the input prompts as a batch - input_ids = tokenizer( - prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) # Generate cos and sin values (commented out as not used in generation) seq_length = input_ids.size(1) @@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): end_time = time.time() # Decode the output ids to text - generated_texts = [ - tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids - ] + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] generation_time = end_time - start_time num_tokens = sum(len(output_id) for output_id in output_ids) @@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -74,25 +71,29 @@ def profile(model, input_data): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): parser = argparse.ArgumentParser() - parser.add_argument('--bs', default=16, type=int) - parser.add_argument('--in_seq_len', default=32, type=int) - parser.add_argument('--out_seq_len', default=128, type=int) - parser.add_argument('--bitblas', action='store_true') + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") args = parser.parse_args() bs = args.bs in_seq_len = args.in_seq_len out_seq_len = args.out_seq_len is_bitblas = args.bitblas - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) if is_bitblas: with torch.no_grad(): model.quantize() @@ -109,5 +110,5 @@ def main(): print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/benchmark_inference_latency.py b/examples/bitnet-1.58b/benchmark_inference_latency.py index 9ce7a38..788fc55 100644 --- a/examples/bitnet-1.58b/benchmark_inference_latency.py +++ b/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,8 +36,8 @@ def profile(model, input_data): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, @@ -52,5 +53,5 @@ def main(): print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/configuration_bitnet.py b/examples/bitnet-1.58b/configuration_bitnet.py index 5f4937b..63c499d 100644 --- a/examples/bitnet-1.58b/configuration_bitnet.py +++ b/examples/bitnet-1.58b/configuration_bitnet.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" LLaMA model configuration""" +"""LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig): return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}") + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, - float) or rope_scaling_factor <= 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/examples/bitnet-1.58b/eval_correctness.py b/examples/bitnet-1.58b/eval_correctness.py index ac1e340..11d4700 100644 --- a/examples/bitnet-1.58b/eval_correctness.py +++ b/examples/bitnet-1.58b/eval_correctness.py @@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -69,18 +69,22 @@ def profile(model, input_data): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=False, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) - input_id = tokenizer("Hello")['input_ids'] + input_id = tokenizer("Hello")["input_ids"] input_id = torch.tensor(input_id).unsqueeze(0).cuda() print("original model generated text:") @@ -91,5 +95,5 @@ def main(): print(generate_text(model, tokenizer, "Hello", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_gpu_memory.py b/examples/bitnet-1.58b/eval_gpu_memory.py index 597cbbf..00c914c 100644 --- a/examples/bitnet-1.58b/eval_gpu_memory.py +++ b/examples/bitnet-1.58b/eval_gpu_memory.py @@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,17 +36,17 @@ def profile(model, input_data): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") with torch.no_grad(): model._post_process_weights() - print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_ppl.py b/examples/bitnet-1.58b/eval_ppl.py index 61c8488..97db2d0 100644 --- a/examples/bitnet-1.58b/eval_ppl.py +++ b/examples/bitnet-1.58b/eval_ppl.py @@ -15,9 +15,9 @@ from tqdm import tqdm torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--seed', default=0, type=int) -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) -parser.add_argument('--seqlen', default=2048, type=int) +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) def calulate_loss(model, input, loss_fct): @@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct): def main(args): - datasets = ['c4', 'wikitext2'] - model = BitnetForCausalLM.from_pretrained( - args.hf_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) with torch.no_grad(): model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) @@ -48,9 +52,9 @@ def main(args): for ii in progress: input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) loss = calulate_loss(model, input, loss_fct) - count += (input.size(-1) - 1) + count += input.size(-1) - 1 acc_loss += loss.item() - progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}") + progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}") avg_loss = acc_loss / count / math.log(2) ppl.append(2**avg_loss) @@ -60,7 +64,7 @@ def main(args): print("Avg PPL:", sum(ppl) / len(ppl)) -if __name__ == '__main__': +if __name__ == "__main__": torch.set_grad_enabled(False) args = parser.parse_args() random.seed(args.seed) diff --git a/examples/bitnet-1.58b/eval_utils.py b/examples/bitnet-1.58b/eval_utils.py index 46241ee..72480c3 100644 --- a/examples/bitnet-1.58b/eval_utils.py +++ b/examples/bitnet-1.58b/eval_utils.py @@ -15,21 +15,17 @@ def set_seed(seed): def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - testdata = "".join(testdata['text']).split('\n') + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") elif dataset_name == "c4": - testdata = load_dataset( - 'allenai/c4', - data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, - split='validation')['text'] + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] else: raise NotImplementedError testdata = [item for item in testdata if item != ""] - tokenized_text = [ - tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] - for item in testdata - ] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): - def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM): return out def _model_generate(self, context, max_length, eos_token_id): - return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index e5af16c..35a044e 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode( @T.prim_func def kernel( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer(C_shape, out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] T.call_extern( @@ -156,9 +155,9 @@ def bitnet_158_int8xint2_decode( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -168,7 +167,8 @@ def bitnet_158_int8xint2_decode( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_decode_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index d8b1f62..d68a012 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -8,11 +8,13 @@ import tilelang.language as T from tilelang import tvm as tvm from tvm import DataType from tilelang.intrinsics.mma_layout import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) import numpy as np from tilelang.intrinsics.mma_macro_generator import ( - INT4TensorCoreIntrinEmitter,) + INT4TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func torch.manual_seed(42) @@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), ): """ - GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. - This kernel: - - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. - - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. - - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. - - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. + - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. - Parameters: - A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. - B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. - C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). - Side effects: - Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. + Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. """ with T.Kernel( - T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=threads, - prelude=decode_i2s_to_i8s, + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, ) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) @@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill( T.clear(C_frag) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill( for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = ( - i * threads * local_size_compressed + - thread_bindings * local_size_compressed + v) + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v vi, vj = T.index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj] @@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill( ) for v in T.vectorized(0, local_size): - index = (i * threads * local_size + thread_bindings * local_size + v) + index = i * threads * local_size + thread_bindings * local_size + v vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) B_dequantize_shared[vi, vj] = B_dequantize_local[v] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_frag, @@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_prefill_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py index 9864635..f2a0e2e 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -6,7 +6,8 @@ from tvm import tl as TL import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from bitblas.base import simplify_prim_func torch.manual_seed(0) @@ -101,12 +102,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def tl_matmul( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -127,7 +129,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/examples/bitnet-1.58b/load_from_quantized.py b/examples/bitnet-1.58b/load_from_quantized.py index 26a32f9..8c775aa 100644 --- a/examples/bitnet-1.58b/load_from_quantized.py +++ b/examples/bitnet-1.58b/load_from_quantized.py @@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100): def main(): # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) diff --git a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py index 1e29a55..2604ef3 100644 --- a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py +++ b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None) args = parser.parse_args() model_name_or_path = args.model_name_or_path -saved_model_path = os.path.join( - dirpath, "models", - f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) def generate_text(model, tokenizer, prompt, max_length=100): @@ -67,7 +67,10 @@ def main(): model_name_or_path, use_flash_attention_2=False, torch_dtype=torch.float16, - ).cuda().half()) + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") @@ -112,10 +115,16 @@ def main(): file_path = cached_file(model_name_or_path, file) os.system(f"cp {file_path} {saved_model_path}") # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) print("quantized model generated text:") print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/modeling_bitnet.py b/examples/bitnet-1.58b/modeling_bitnet.py index 6e3c42b..1830995 100644 --- a/examples/bitnet-1.58b/modeling_bitnet.py +++ b/examples/bitnet-1.58b/modeling_bitnet.py @@ -64,8 +64,7 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update( - find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res @@ -87,7 +86,6 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -108,34 +106,23 @@ ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm) class BitnetRotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base - **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer( - "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -156,14 +143,12 @@ class BitnetRotaryEmbedding(nn.Module): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, - None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, - str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -174,8 +159,8 @@ class BitnetRotaryEmbedding(nn.Module): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -207,7 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -245,7 +229,6 @@ class BitnetMLP(nn.Module): class BitnetMLPFuseGateUp(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -272,8 +255,7 @@ class BitnetMLPFuseGateUp(nn.Module): def from_bit_mlp(cls, bit_mlp: BitnetMLP): module = cls(bit_mlp.config) # assign the weights - module.gate_up_proj.weight = nn.Parameter( - torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) module.down_proj = bit_mlp.down_proj module.ffn_layernorm = bit_mlp.ffn_layernorm return module @@ -295,8 +277,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -311,7 +292,8 @@ class BitnetAttention(nn.Module): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -325,8 +307,8 @@ class BitnetAttention(nn.Module): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) self.q_proj = BitLinear( self.hidden_size, @@ -387,10 +369,8 @@ class BitnetAttention(nn.Module): value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -399,30 +379,24 @@ class BitnetAttention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -448,7 +422,8 @@ class BitnetAttentionQKVFused(nn.Module): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -462,8 +437,8 @@ class BitnetAttentionQKVFused(nn.Module): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) self.qkv_proj = BitLinear( self.hidden_size, @@ -497,17 +472,12 @@ class BitnetAttentionQKVFused(nn.Module): module = cls(bit_attention.config, bit_attention.layer_idx) # assign the weights module.qkv_proj.weight = nn.Parameter( - torch.cat([ - bit_attention.q_proj.weight, bit_attention.k_proj.weight, - bit_attention.v_proj.weight - ], - dim=0)) + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: module.qkv_proj.bias = nn.Parameter( - torch.cat([ - bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias - ], - dim=0)) + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) module.o_proj = bit_attention.o_proj module.inner_attn_ln = bit_attention.inner_attn_ln if bit_attention.config.rope_scaling is None: @@ -528,16 +498,13 @@ class BitnetAttentionQKVFused(nn.Module): bsz, q_len, _ = hidden_states.size() qkv_states = self.qkv_proj(hidden_states) query_states, key_states, value_states = torch.split( - qkv_states, [ - self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -546,30 +513,24 @@ class BitnetAttentionQKVFused(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -622,10 +583,8 @@ class BitnetFlashAttention2(BitnetAttention): # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -635,8 +594,7 @@ class BitnetFlashAttention2(BitnetAttention): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -665,14 +623,14 @@ class BitnetFlashAttention2(BitnetAttention): logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -683,14 +641,9 @@ class BitnetFlashAttention2(BitnetAttention): return attn_output, attn_weights, past_key_value - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. @@ -720,7 +673,8 @@ class BitnetFlashAttention2(BitnetAttention): if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length) + query_states, key_states, value_states, attention_mask, query_length + ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -740,13 +694,7 @@ class BitnetFlashAttention2(BitnetAttention): attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) return attn_output @@ -754,28 +702,24 @@ class BitnetFlashAttention2(BitnetAttention): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, - device=query_layer.device) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -794,13 +738,11 @@ LLAMA_ATTENTION_CLASSES = { class BitnetDecoderLayer(nn.Module): - def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -834,7 +776,8 @@ class BitnetDecoderLayer(nn.Module): if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", - stacklevel=2) + stacklevel=2, + ) residual = hidden_states @@ -925,8 +868,7 @@ class BitnetPreTrainedModel(PreTrainedModel): dtype = self.config._pre_quantization_dtype else: dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -1025,9 +967,7 @@ class BitnetModel(BitnetPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1055,21 +995,15 @@ class BitnetModel(BitnetPreTrainedModel): cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: @@ -1083,10 +1017,7 @@ class BitnetModel(BitnetPreTrainedModel): if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1143,12 +1074,9 @@ class BitnetModel(BitnetPreTrainedModel): next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, Cache) else next_decoder_cache) + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1172,14 +1100,9 @@ class BitnetModel(BitnetPreTrainedModel): if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - - causal_mask = torch.full((sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -1188,10 +1111,8 @@ class BitnetModel(BitnetPreTrainedModel): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq( - 0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( - padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. @@ -1201,8 +1122,7 @@ class BitnetModel(BitnetPreTrainedModel): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[:mask_shape[0], :mask_shape[1], - offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice return causal_mask @@ -1279,9 +1199,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1327,13 +1245,9 @@ class BitnetForCausalLM(BitnetPreTrainedModel): attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1344,13 +1258,13 @@ class BitnetForCausalLM(BitnetPreTrainedModel): past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[ - 0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None else None) - cache_length = past_length if max_cache_length is None else torch.min( - max_cache_length, past_length) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1361,7 +1275,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1369,8 +1283,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if (max_cache_length is not None and attention_mask is not None and - cache_length + input_ids.shape[1] > max_cache_length): + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids") @@ -1379,7 +1292,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1392,39 +1305,38 @@ class BitnetForCausalLM(BitnetPreTrainedModel): input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange( - past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update({ - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past),) + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past @staticmethod def recursive_set(model, name, attr): - ''' - set layers.25.mlp.up_proj to attr - ''' + """ + set layers.25.mlp.up_proj to attr + """ - names = name.split('.') + names = name.split(".") obj = model for n in names[:-1]: obj = getattr(obj, n) @@ -1521,6 +1433,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate + if checkpoint_format == "bitblas": model = cls(config) for name, module in model.named_modules(): @@ -1567,7 +1480,6 @@ class BitnetForCausalLM(BitnetPreTrainedModel): LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1631,8 +1543,7 @@ class BitnetForSequenceClassification(BitnetPreTrainedModel): else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, - self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1646,8 +1557,7 @@ class BitnetForSequenceClassification(BitnetPreTrainedModel): if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or - labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py index 6fea325..2adfd6d 100644 --- a/examples/bitnet-1.58b/tokenization_bitnet.py +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for LLaMA.""" + import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -37,12 +38,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,14 +158,10 @@ class BitnetTokenizer(PreTrainedTokenizer): **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -174,7 +169,8 @@ class BitnetTokenizer(PreTrainedTokenizer): " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565") + " https://github.com/huggingface/transformers/pull/24565" + ) legacy = True self.legacy = legacy @@ -214,8 +210,7 @@ class BitnetTokenizer(PreTrainedTokenizer): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf( - f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -261,8 +256,7 @@ class BitnetTokenizer(PreTrainedTokenizer): tokens = super().tokenize(text, **kwargs) - if len(tokens - ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -284,7 +278,7 @@ class BitnetTokenizer(PreTrainedTokenizer): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -332,12 +326,9 @@ class BitnetTokenizer(PreTrainedTokenizer): if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join(save_directory, - (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["vocab_file"]) + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( - self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -357,10 +348,9 @@ class BitnetTokenizer(PreTrainedTokenizer): return output - def get_special_tokens_mask(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -377,20 +367,16 @@ class BitnetTokenizer(PreTrainedTokenizer): `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + - ([0] * len(token_ids_1)) + eos_token_id) + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id - def create_token_type_ids_from_sequences(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,9 +459,9 @@ class BitnetTokenizer(PreTrainedTokenizer): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}") - template = template.replace("USE_DEFAULT_PROMPT", - "true" if self.use_default_system_prompt else "false") + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) diff --git a/examples/bitnet-1.58b/utils_quant.py b/examples/bitnet-1.58b/utils_quant.py index 5f5db5d..5a50edb 100644 --- a/examples/bitnet-1.58b/utils_quant.py +++ b/examples/bitnet-1.58b/utils_quant.py @@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1): def activation_quant(x, num_bits=8): dtype = x.dtype x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) / s return result.type(dtype) class BitLinearBitBLAS(nn.Module): - def __init__( self, in_features: int, @@ -68,7 +67,7 @@ class BitLinearBitBLAS(nn.Module): self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) self.format = "bitnet" - self.Qp = 2**(self.input_bits - 1) - 1 + self.Qp = 2 ** (self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): if global_operator_cache.size() == 0: @@ -99,8 +98,7 @@ class BitLinearBitBLAS(nn.Module): @classmethod def from_bit_linear(cls, bitlinear, weight_group=1): - bitblas_linear = cls( - bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) @@ -158,8 +156,8 @@ class BitLinearBitBLAS(nn.Module): @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8), s @@ -173,9 +171,8 @@ class BitLinearBitBLAS(nn.Module): # for the correctness evaluation. def native_forward(self, input): - quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) - quant_weight = ( - self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()) + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: @@ -214,7 +211,6 @@ class BitLinearBitBLAS(nn.Module): # Naive BitLinear from HuggingFace class BitLinear(nn.Linear): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): super(BitLinear, self).__init__(*kargs, **kwargs) """ @@ -224,10 +220,8 @@ class BitLinear(nn.Linear): self.input_bits = input_bits def forward(self, input): - quant_input = input + (activation_quant(input, self.input_bits) - input).detach() - quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - - self.weight).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: diff --git a/examples/bitnet-1.58b/vllm_workspace/conftest.py b/examples/bitnet-1.58b/vllm_workspace/conftest.py index 951f389..e9e2997 100644 --- a/examples/bitnet-1.58b/vllm_workspace/conftest.py +++ b/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -20,7 +20,7 @@ from transformers import ( from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig -from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger from vllm.sequence import SampleLogprobs @@ -56,12 +56,13 @@ else: class _ImageAssets(_ImageAssetsBase): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -136,7 +137,6 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) class HfRunner: - def wrap_device(self, input: _T) -> _T: if not is_cpu(): return input.to("cuda") @@ -166,7 +166,8 @@ class HfRunner: SentenceTransformer( model_name, device="cpu", - ).to(dtype=torch_dtype)) + ).to(dtype=torch_dtype) + ) else: if is_vision_model: auto_cls = AutoModelForVision2Seq @@ -184,7 +185,8 @@ class HfRunner: torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs, - )) + ) + ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -204,8 +206,7 @@ class HfRunner: ) except Exception: logger.warning( - "Unable to auto-load processor from HuggingFace for " - "model %s. Using tokenizer instead.", + "Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.", model_name, ) self.processor = self.tokenizer @@ -362,7 +363,7 @@ class HfRunner: last_hidden_states, self.model.get_output_embeddings().weight.t(), ) - if (getattr(self.model.get_output_embeddings(), "bias", None) is not None): + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) @@ -389,8 +390,7 @@ class HfRunner: all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -409,7 +409,6 @@ def hf_runner(): class VllmRunner: - def __init__( self, model_name: str, @@ -514,12 +513,10 @@ class VllmRunner: num_logprobs: int, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams( - temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def generate_beam_search( self, diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py index 55a2454..ea18239 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -32,15 +32,14 @@ args = parser.parse_args() ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitblas", - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], - max_tokens=1024) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) print("bitnet inference:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py index 4f5f87f..f631fb3 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -33,13 +33,13 @@ args = parser.parse_args() ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitnet_bitblas", - gpu_memory_utilization=0.5, - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:") diff --git a/examples/bitnet-1.58b/vllm_workspace/utils.py b/examples/bitnet-1.58b/vllm_workspace/utils.py index daa9d8f..e96b19e 100644 --- a/examples/bitnet-1.58b/vllm_workspace/utils.py +++ b/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -3,8 +3,7 @@ from typing import Dict, List, Tuple TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], - name_0: str, name_1: str): +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): """ Compare the two sequences generated by different models, which should be equal. @@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 - assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] -def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" # Break out since sequences will now diverge. break diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 014f0c5..1794836 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) # print @@ -73,8 +69,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -154,7 +149,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -192,24 +187,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -254,7 +237,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -278,9 +260,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -288,9 +270,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -302,22 +282,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) @@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl(): print("downsample_factor", downsample_factor) downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension print("downsample_len", downsample_len) - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 7e90db7..afb4cc8 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 1 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def blocksparse_flashattn( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return blocksparse_flashattn @@ -180,18 +175,16 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -202,15 +195,15 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 1c4b847..99418d5 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, - max_num_blocks_per_seq, max_selected_blocks): + }, + ) + def kernel_func( + block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + ): shape_q = [batch, heads, dim] shape_k = [num_pages, page_block_size, heads_kv, dim] shape_v = [num_pages, page_block_size, heads_kv, dim_v] @@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor(shape_block_table, "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): @@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): block_table_idx = T.floordiv(logical_block_idx, block_ratio) block_tile_idx = T.floormod(logical_block_idx, block_ratio) physical_block_idx = block_table[bid, block_table_idx] - T.copy( - K[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], K_shared) + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], V_shared) + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale_local = T.alloc_local([1], accum_dtype) max_split = T.alloc_local([1], "int32") - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) lse_max_local[0] = -T.infinity(accum_dtype) for k in T.serial(num_split): lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): + if lse_local_split[0] != 0: max_split[0] = k lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) @@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor(shape_block_table, "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, - Output_partial) + flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial) combine(glse, Output_partial, Output) return main @@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel( query, @@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module): return output -def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_size): +def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): """ Paged version of sparse attention reference implementation. - + Args: query: [batch, heads, dim] - key_cache: [num_pages, page_block_size, heads_kv, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] value_cache: [num_pages, page_block_size, heads_kv, dim] block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices cache_seqlens: [batch] - actual sequence lengths @@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ # Reconstruct the full key and value tensors from paged cache max_cache_seqlen = max(cache_seqlens).item() - key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), - dtype=key_cache.dtype, - device=key_cache.device) - value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), - dtype=value_cache.dtype, - device=value_cache.device) + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) # Reconstruct full tensors from paged cache using block_table for b in range(batch): @@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ actual_block_size = end_token - start_token # Copy from paged cache to full tensors - key_full[b, :, start_token:end_token, :] = key_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) - value_full[b, :, start_token:end_token, :] = value_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) # Reshape query for grouped attention - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] # Compute attention scores - scores = einsum( - query, key_full, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] # Create sparse mask based on block_indices sparse_mask = torch.zeros_like(scores) @@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ sparse_mask[b, :, h, start_pos:end_pos] = 1 # Apply sparse mask - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) # Apply causal mask based on actual sequence lengths range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) + scores = scores.masked_fill(pad_mask, float("-inf")) # Compute attention weights attention = F.softmax(scores / scale, dim=-1) # Apply attention to values - out = einsum(attention, value_full, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] + out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] # Reshape output back to original format - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) - output = flash_attn_with_kvcache( - query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) output = output.squeeze(1) return output def main(args): - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size @@ -395,35 +371,30 @@ def main(args): dtype = torch.float16 # Generate random inputs - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - cache_seqlens = torch.randint( - max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") print("cache_seqlens: ", cache_seqlens) - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") # Create paged KV cache - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda') - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), - dtype=dtype, - device='cuda') + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") # Create block table and block indices for dense case (all blocks selected) max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) - block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda') - block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), - dtype=torch.int32, - device='cuda') + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") # Fill block table and block indices and cache # Create a pool of available physical blocks - total_blocks_needed = sum( - int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) available_blocks = list(range(total_blocks_needed)) import random + random.seed(42) # For reproducibility random.shuffle(available_blocks) @@ -458,10 +429,8 @@ def main(args): actual_block_size = end_token - start_token # Copy K and V data to the paged cache - K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, - start_token:end_token, :, :] - V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, - start_token:end_token, :, :] + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] # Fill block_indices for sparse attention # For dense case (verification), we select all blocks in reverse order @@ -496,10 +465,9 @@ def main(args): remaining_blocks = [b for b in all_blocks if b not in selected_blocks] if remaining_blocks: import random + random.seed(42) # For reproducibility - additional_blocks = random.sample( - remaining_blocks, - min(num_selected - recent_blocks, len(remaining_blocks))) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) selected_blocks.extend(additional_blocks) # Sort selected blocks in reverse order (most recent first) @@ -512,25 +480,20 @@ def main(args): block_indices[seq_idx, head_idx, i] = -1 # Initialize sparse attention module - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, - num_blocks) - output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table) + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) import flash_attn # noqa: F401 - output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_N) + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) # Check correctness if sparse_ratio == 0.0: max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() - assert torch.allclose( - output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" else: - max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() @@ -574,16 +537,15 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio') - parser.add_argument('--block_N', type=int, default=64, help='block_N') - parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages') - parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages") args = parser.parse_args() main(args) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index b308752..8b5cde3 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, - max_selected_blocks): + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] @@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + # actual_num_blocks: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False @@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): i_s = block_indices[bid, cur_kv_head, start + k] if i_s >= 0: has_valid_block = True - T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale_local = T.alloc_local([1], accum_dtype) max_split = T.alloc_local([1], "int32") - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) lse_max_local[0] = -T.infinity(accum_dtype) for k in T.serial(num_split): lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): + if lse_local_split[0] != 0: max_split[0] = k lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) @@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + # actual_num_blocks: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) @@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output -def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, - max_cache_seqlen, block_size): +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): """ Args: query: [batch, heads, dim] @@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql block_H = 64 actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=block_H, @@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): print(name + " all_close={}".format(all_close)) if not all_close: diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -392,10 +353,10 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # # Ensure at least one element equals cache_seqlen # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index @@ -406,10 +367,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') @@ -418,10 +376,9 @@ def main(batch=8, max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -434,8 +391,7 @@ def main(batch=8, print("max_num_blocks: ", max_num_blocks) # parity reference - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) @@ -445,13 +401,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -469,15 +423,13 @@ def main(batch=8, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 3417bd7..0d75921 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] @@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, "bool"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[bid, hid, start + k]: has_valid_block = True - T.copy( - K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else((start + k) * block_N + j - >= cache_seqlens[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, "bool"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) combine(glse, Output_partial, Output) @@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) + num_blocks=T.dynamic("num_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module): num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_split: ", num_split) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) return output @@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_H = 64 actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks max_selected_blocks = actual_num_blocks.max().item() # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, @@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + num_blocks=T.dynamic("num_blocks"), + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) @@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): # print(expect[3, 28]) # print(actual[3, 28]) diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -384,14 +344,13 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') print("cache_seqlens: ", cache_seqlens) @@ -403,7 +362,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -411,13 +370,12 @@ def main(batch=8, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True # print("block_mask: ", block_mask) # parity reference - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) @@ -427,13 +385,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -452,15 +408,13 @@ def main(batch=8, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 85b72b7..b61d52f 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -79,16 +75,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for i in range(loop_range): block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) @@ -119,23 +110,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -163,18 +149,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] dim_v = value.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -369,34 +331,29 @@ def main(batch=64, dtype = torch.float16 block_H = 64 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence print("cache_seqlens: ", cache_seqlens) max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -408,8 +365,7 @@ def main(batch=64, max_num_blocks = torch.max(max_valid_num_blocks).item() print("max_num_blocks: ", max_num_blocks) - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_indice_triton( Q, @@ -423,8 +379,7 @@ def main(batch=64, ) print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -466,15 +421,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index 3485725..c05b377 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -77,16 +73,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for block_idx in range(loop_range): start_n = (start + block_idx) * BLOCK_N @@ -117,23 +108,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -161,18 +147,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton( return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v block_size = block_size sparse_ratio = sparse_ratio @@ -363,14 +325,13 @@ def main(batch=64, dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence num_blocks = (max_cache_seqlen + block_size - 1) // block_size @@ -379,7 +340,7 @@ def main(batch=64, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -387,11 +348,10 @@ def main(batch=64, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_mask_triton( Q, @@ -404,8 +364,7 @@ def main(batch=64, ) # print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -448,15 +407,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/heuristic.py b/examples/blocksparse_attention/heuristic.py index b60a81d..0e6fc52 100644 --- a/examples/blocksparse_attention/heuristic.py +++ b/examples/blocksparse_attention/heuristic.py @@ -1,8 +1,7 @@ import math -def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, - is_causal_or_local, max_splits): +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): """ Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py index adda1f0..dd33f46 100644 --- a/examples/blocksparse_attention/test_example_blocksparse_attention.py +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_indice(): example_triton_sparse_gqa_decode_varlen_indice.main( - batch=8, - heads=8, - heads_kv=4, - max_cache_seqlen=2048, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) def test_example_triton_sparse_gqa_decode_varlen_mask(): example_triton_sparse_gqa_decode_varlen_mask.main( - batch=16, - heads=16, - heads_kv=8, - max_cache_seqlen=1024, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) if __name__ == "__main__": diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 8cd3a82..0cbef5e 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") -parser.add_argument( - "--use_autotune", action="store_true", default=False, help="Whether to use autotune") +parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") args, _ = parser.parse_known_args() M, N, K = args.m, args.n, args.k @@ -41,17 +40,19 @@ def get_configs(): thread_num = [128, 256] enable_rasterization = [True, False] - _configs = list( - itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) - return [{ - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], - } for c in _configs] + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] def ref_program(A, B, BlockMask, block_M, block_N, block_K): @@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]): return input_tensors -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit(out_idx=[-1]) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" +): block_mask_shape = (M // block_M, N // block_N, K // block_K) @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -134,7 +126,6 @@ def blocksparse_matmul(M, def main(): - # Initialize input matrices A and B on the GPU with half precision a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() @@ -147,8 +138,7 @@ def main(): best_config = kernel.config best_latency = kernel.latency - block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ - "block_K"] + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] print(f"Best Config: {best_config}") print(f"Sparsity Ratio: {sparsity}") @@ -163,7 +153,8 @@ def main(): block_K=DEFAULT_BLOCK_K, num_stages=DEFAULT_NUM_STAGES, thread_num=DEFAULT_THREAD_NUM, - enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") # Create block mask with desired sparsity diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 102ac20..ec15b29 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): fp8_max = 448.0 @T.prim_func - def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( - (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( - (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): - with T.Kernel( - T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), "int32"), + X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): row = bx row_g_id = by bg = bz @@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") row_offset = T.alloc_fragment((1,), "int32") - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) row_offset[0] = 0 for i in T.serial(bg): row_offset[0] += batch_sizes[i] T.copy( - X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size], y_local) + X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) - y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_amax_local[i] / fp8_max, 0) + y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) for i, j in T.Parallel(blk_m, group_size): y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) T.copy(y_q_local, y_q_local_fp8) for i, j in T.Parallel(blk_m, group_size): - y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_q_local[i, j], 0) + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) for i in T.Parallel(blk_m): X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return group_per_split_token_cast @@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: return x.squeeze(0) if remove_dim else x # Normal layout requires transposing - aligned_x = torch.transpose( - torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x @@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() return x_fp8, (x_amax / 448.0).view(m, -1) -def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ - Tuple[torch.Tensor, torch.Tensor]: + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # assert x.shape[0] == batch_sizes.sum() M_max = ceil_div(batch_sizes.max(), 128) * 128 split_x = torch.split(x, batch_sizes.tolist(), dim=0) padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] - x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 484a092..45281ab 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m): fp8_max = 448.0 @T.prim_func - def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), - X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): row = bx row_g_id = by @@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m): y_q_local = T.alloc_fragment((blk_m, group_size), dtype) y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) - T.copy( - X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], - y_local) + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) @@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m): T.copy(y_q_local, y_q_local_fp8) for i in T.Parallel(blk_m): X_amax[row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return per_token_cast @@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8): from example_triton_cast_to_fp8 import per_token_group_quant_fp8 def run_triton(): - x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( - x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) return x_fp8_triton_, x_amax_triton_ x_fp8_triton, x_amax_triton = run_triton() diff --git a/examples/cast/example_triton_cast_to_fp8.py b/examples/cast/example_triton_cast_to_fp8.py index cc56def..1859433 100644 --- a/examples/cast/example_triton_cast_to_fp8.py +++ b/examples/cast/example_triton_cast_to_fp8.py @@ -128,9 +128,7 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ - assert (x.shape[-1] % - group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py index 1ca000e..e8b10a7 100644 --- a/examples/cast/test_example_cast.py +++ b/examples/cast/test_example_cast.py @@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8 def test_example_group_per_split_token_cast_to_fp8(): - example_group_per_split_token_cast_to_fp8.main( - M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) def test_example_per_token_cast_to_fp8(): diff --git a/examples/compile_flags/usecase.py b/examples/compile_flags/usecase.py index 8451b04..80e2b78 100644 --- a/examples/compile_flags/usecase.py +++ b/examples/compile_flags/usecase.py @@ -4,12 +4,11 @@ import tilelang.language as T # @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -36,8 +35,7 @@ block_K = 32 func = matmul(M, N, K, block_M, block_N, block_K) -jit_kernel = tilelang.compile( - func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) diff --git a/examples/conftest.py b/examples/conftest.py index 9f49d40..4010e0d 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b2696ba..a84e587 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation): @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -51,13 +35,11 @@ def convolution(N, @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -66,11 +48,13 @@ def convolution(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -82,10 +66,8 @@ def convolution(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -97,15 +79,15 @@ def convolution(N, def main(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 3936774..600b608 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -40,7 +39,8 @@ def get_configs(): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -50,7 +50,8 @@ def get_configs(): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -64,53 +65,18 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" +): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -120,13 +86,11 @@ def convolution(N, @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -136,9 +100,11 @@ def convolution(N, out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) if is_hopper: - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -150,10 +116,8 @@ def convolution(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -166,17 +130,19 @@ def convolution(N, return main -def main(n: int = 128, - c: int = 128, - h: int = 64, - w: int = 64, - f: int = 128, - k: int = 3, - s: int = 1, - d: int = 1, - p: int = 1, - use_autotune: bool = False, - with_roller: bool = True): +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) @@ -196,25 +162,16 @@ def main(n: int = 128, if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=True, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() - main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, - args.with_roller) + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 715f09a..8aba914 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -41,14 +41,13 @@ def tl_gemm( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - scales_a: T.Tensor(Scales_A_shape, "float32"), - scales_b: T.Tensor(Scales_B_shape, "float32"), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, "float32"), + scales_b: T.Tensor(Scales_B_shape, "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) @@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros( - ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): @@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): c_acc.zero_() for k in range(ceildiv(K, 128)): c = torch._scaled_mm( - A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], - B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, + A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(), - out_dtype=torch.bfloat16) + out_dtype=torch.bfloat16, + ) c_acc += c.to(torch.float32) - C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) + C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) return C diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index 61c3b63..4995837 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -8,6 +8,7 @@ import argparse def get_configs(): import itertools + BLOCK_N = [16, 32, 64, 128] BLOCK_H = [16, 32, 64, 128] num_split = [1, 2, 4, 8, 16, 32] @@ -15,30 +16,26 @@ def get_configs(): _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) - return [{ - "block_N": c[0], - "block_H": c[1], - "num_split": c[2], - "threads": c[3], - } for c in _configs] + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] @tilelang.autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashmla_decode(batch, - heads, - kv_head_num, - seqlen_kv, - dim, - pe_dim, - block_N, - block_H, - num_split, - threads=128): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -47,11 +44,11 @@ def flashmla_decode(batch, @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): Q_local = T.alloc_fragment([block_H, dim], dtype) @@ -70,24 +67,19 @@ def flashmla_decode(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=0): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -107,20 +99,18 @@ def flashmla_decode(batch, T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -136,8 +126,8 @@ def flashmla_decode(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -150,12 +140,7 @@ def flashmla_decode(batch, T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -176,14 +161,14 @@ def flashmla_decode(batch, acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) - T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim], dtype) @@ -193,9 +178,11 @@ def flashmla_decode(batch, lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -218,26 +205,26 @@ def flashmla_decode(batch, @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -262,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') - parser.add_argument('--autotune', action='store_true', help='auto tune') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim enable_autotune = args.autotune @@ -314,17 +294,7 @@ if __name__ == "__main__": if enable_autotune: kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) else: - kernel = flashmla_decode( - batch, - heads, - kv_heads, - kv_ctx, - dim, - pe_dim, - BLOCK_N, - BLOCK_H, - num_split, - threads=threads) + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) input_tensors = profiler._get_inputs() tilelang_output = kernel(*input_tensors) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py index 0006d94..18c0a5f 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py @@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -94,8 +93,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -141,9 +139,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -309,24 +305,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -429,26 +422,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -470,26 +459,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py index 644f97d..861e841 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -91,8 +90,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -138,9 +136,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -306,24 +302,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -426,26 +419,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -467,26 +456,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py index a542ff6..544b5e1 100644 --- a/examples/deepseek_mla/benchmark_mla.py +++ b/examples/deepseek_mla/benchmark_mla.py @@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, @torch.inference_mode() -def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): from flash_mla import flash_mla_with_kvcache, get_mla_metadata blocked_v = blocked_k[..., :dv] @@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, @torch.inference_mode() -def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # pip install flashinfer-python import flashinfer + assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() kv_indptr = [0] kv_indices = [] @@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") mla_wrapper.plan( q_indptr, kv_indptr, @@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ) def flashinfer(): - output, lse = mla_wrapper.run( - q_nope.view(-1, h_q, dv), - q_pe.view(-1, h_q, d - dv), - blocked_k_nope, - blocked_k_pe, - return_lse=True) + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) out_flash, lse_flash = flashinfer() @@ -177,8 +168,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -224,9 +214,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -393,24 +381,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, @torch.inference_mode() -def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( @@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flashinfer", "flash_mla_triton", "tilelang" - ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: # flashinfer has a different lse return value # flash_mla_triton and flash_mla_tilelang doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -558,26 +538,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] def get_args(): @@ -599,26 +575,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 3932d11..733ae3c 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -8,11 +8,12 @@ import argparse @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -22,11 +23,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -44,33 +45,24 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -90,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=256) as (bid, hid, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -121,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -139,14 +131,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -168,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, :]) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) @@ -187,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -212,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -256,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -298,10 +278,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -311,12 +290,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index d23ff00..dee05c1 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -8,22 +8,14 @@ import math @tilelang.jit( - out_idx=[8], pass_configs={ + out_idx=[8], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def mla_decode_tilelang(batch, - h_q, - h_kv, - max_seqlen_pad, - dv, - dpe, - block_N, - block_H, - num_split, - block_size, - softmax_scale=None): + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): if softmax_scale is None: - softmax_scale = (dv + dpe)**-0.5 + softmax_scale = (dv + dpe) ** -0.5 scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -34,13 +26,13 @@ def mla_decode_tilelang(batch, @T.macro def flash_mla_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + Output: T.Tensor([batch, h_q, dv], dtype), ): with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dv], dtype) @@ -59,13 +51,15 @@ def mla_decode_tilelang(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -73,25 +67,17 @@ def mla_decode_tilelang(batch, loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) for kr in T.Pipelined(loop_range, num_stages=2): k = loop_range - 1 - kr - kv_start = BLOCK_TABLE[bx, (k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) if kr == 0: for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) @@ -109,21 +95,20 @@ def mla_decode_tilelang(batch, for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) @T.macro def flash_mla_split_kv_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), ): - with T.Kernel( - batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -141,13 +126,15 @@ def mla_decode_tilelang(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -155,28 +142,20 @@ def mla_decode_tilelang(batch, total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) blocks_per_split = T.floordiv(total_blocks, num_split) remaining_blocks = T.floormod(total_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N for k in T.Pipelined(loop_range, num_stages=2): - kv_start = BLOCK_TABLE[bx, (start + k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) @@ -196,15 +175,15 @@ def mla_decode_tilelang(batch, acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): with T.Kernel(h_q, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dv], dtype) @@ -214,9 +193,11 @@ def mla_decode_tilelang(batch, lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -239,31 +220,30 @@ def mla_decode_tilelang(batch, @T.prim_func def main_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, - Output_partial) + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) @@ -284,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) - temp_mask = torch.ones( - s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -295,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # q: [b, s_q, h_q, d] # block_table: [b, max_seqlen_pad // block_size] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] @@ -325,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, return out_torch -def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): - +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -341,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size, softmax_scale) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) def flash_mla_tilelang(): @@ -360,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_flash = flash_mla_tilelang() t = do_bench(flash_mla_tilelang) - out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) print("All close") return out_flash, t @@ -369,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--h_q', type=int, default=128, help='q heads number') - parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') - parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') - parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') - parser.add_argument('--dv', type=int, default=512, help='value head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") args = parser.parse_args() b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv @@ -383,9 +356,7 @@ if __name__ == "__main__": s_q = 1 # for decode, s_q = 1 block_size = 64 - cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], - dtype=torch.int32, - device=device) + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) dpe = d - dv causal = True @@ -397,12 +368,11 @@ if __name__ == "__main__": total_flops = s_q * total_seqlens * h_q * d * 2 q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32, - device=device).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) - out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 2f896f2..305fd30 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -9,11 +9,13 @@ import argparse @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split_persistent( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(sm_num, threads=256) as (block_id): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.use_swizzle(10) total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split @@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,26 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -117,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) # T.copy(acc_o, O_shared) - T.copy( - acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - sid, :]) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) T.sync_grid() waves = T.ceildiv(heads * batch, sm_num) @@ -167,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index fcd427e..3fb90a5 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -13,14 +13,19 @@ import argparse tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): sm_scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -30,11 +35,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) @@ -75,16 +80,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -166,8 +171,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - 0:dim // 2]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -197,8 +201,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - dim // 2:dim]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) elif tx >= 256: # producer @@ -211,19 +214,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 @@ -233,33 +234,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=384) as (bid, hid, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -298,16 +295,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -389,10 +386,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy( - O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, 0:dim // 2]) - T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -422,9 +417,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy( - O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, dim // 2:dim]) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) elif tx >= 256: # producer @@ -433,54 +426,48 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) @@ -490,9 +477,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -515,26 +504,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -559,31 +548,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -601,10 +583,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -614,12 +595,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index b141822..4a1a84c 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -8,11 +8,13 @@ import argparse @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" q_dtype = "float8_e4m3" accum_dtype = "float" @@ -22,11 +24,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -46,31 +48,27 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.disable_warp_group_reg_alloc() loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(qKV_shared, KV_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -90,7 +88,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) return main_no_split @@ -108,42 +106,35 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/torch_refs.py b/examples/deepseek_mla/torch_refs.py index 4b4c888..aae6c7c 100644 --- a/examples/deepseek_mla/torch_refs.py +++ b/examples/deepseek_mla/torch_refs.py @@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): block_N = 64 seqlen_kv = KV.size(1) - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) @@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bhd,bkhd->bhk', Q_, - KV_[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] acc_s += torch.einsum( - 'bhd,bkhd->bhk', Q_pe_, - K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): acc_s = torch.exp2(acc_s - scores_max[:, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bhk,bkhd->bhd', acc_s_cast, - KV_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, None] diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index daee398..ea3f72c 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -14,21 +14,44 @@ from fla.ops.utils import prepare_token_indices from fla.utils import autocast_custom_fwd, contiguous -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -100,8 +120,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -172,7 +191,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -195,7 +213,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -207,18 +226,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) def get_configs(): import itertools + iter_params = dict( block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def tilelang_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - block_T=128, - num_stages=2, - threads=32): + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch, @T.prim_func def tilelang_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -520,7 +523,7 @@ def tilelang_sparse_attention(batch, i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -530,21 +533,15 @@ def tilelang_sparse_attention(batch, i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -564,45 +561,33 @@ def tilelang_sparse_attention(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return tilelang_sparse_attention def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): """Generate random block indices for the benchmark.""" - block_indices = torch.full((batch, seq_len, heads, selected_blocks), - seq_len, - dtype=torch.long, - device='cuda') + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") for b in range(batch): for t in range(seq_len): for h in range(heads): i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i return block_indices.sort(-1)[0] -def benchmark_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -628,14 +613,13 @@ def benchmark_nsa(batch_size, print(f"Profiler latency: {profiler_latency} ms") # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") # Generate block indices - block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, - block_size).to(torch.int32) + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) # Warmup for _ in range(warmup): @@ -666,10 +650,9 @@ def benchmark_nsa(batch_size, # Validate result against reference if requested if validate: - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") ref = naive_nsa( q=Q, @@ -700,22 +683,13 @@ def benchmark_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def benchmark_triton_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the Triton-based TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -723,18 +697,17 @@ def benchmark_triton_nsa(batch_size, torch.random.manual_seed(0) # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") # Generate block indices block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') - o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device='cuda') + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") # Warmup for _ in range(warmup): @@ -750,7 +723,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) # Synchronize before timing torch.cuda.synchronize() @@ -770,7 +744,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) torch.cuda.synchronize() end_time = time.time() @@ -815,54 +790,28 @@ def benchmark_triton_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def run_benchmark_suite(impl='all'): +def run_benchmark_suite(impl="all"): """Run a suite of benchmarks with different configurations.""" # Define configurations to benchmark configs = [ # Small model config - Note: head_query must be a multiple of heads*16 for Triton - { - "batch_size": 2, - "seq_len": 1024, - "heads": 8, - "head_query": 8 * 16, - "dim": 64, - "selected_blocks": 8, - "block_size": 32 - }, - + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, # Medium model config - { - "batch_size": 2, - "seq_len": 2048, - "heads": 16, - "head_query": 16 * 16, - "dim": 64, - "selected_blocks": 16, - "block_size": 64 - }, - + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, # Large model config - { - "batch_size": 1, - "seq_len": 4096, - "heads": 32, - "head_query": 32 * 16, - "dim": 128, - "selected_blocks": 32, - "block_size": 128 - }, + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, ] results = [] for config in configs: print(f"Running benchmark with config: {config}") - if impl in ['all', 'tilelang']: + if impl in ["all", "tilelang"]: print("Benchmarking TileLang implementation:") result = benchmark_nsa( batch_size=config["batch_size"], @@ -874,12 +823,13 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "tilelang", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all', 'triton']: + if impl in ["all", "triton"]: print("Benchmarking Triton implementation:") result = benchmark_triton_nsa( batch_size=config["batch_size"], @@ -891,19 +841,24 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "triton", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all']: + if impl in ["all"]: # Print comparison if both implementations were run tilelang_result = next( - r for r in results if r["impl"] == "tilelang" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) triton_result = next( - r for r in results if r["impl"] == "triton" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") @@ -921,8 +876,7 @@ if __name__ == "__main__": parser.add_argument("--dim", type=int, default=128, help="Head dimension") parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") parser.add_argument("--block_size", type=int, default=32, help="Block size") - parser.add_argument( - "--dtype", type=str, default="float16", help="Data type (float16 or float32)") + parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)") parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") @@ -933,7 +887,8 @@ if __name__ == "__main__": type=str, default="all", choices=["tilelang", "triton", "all"], - help="Implementation to benchmark (tilelang, triton, or all)") + help="Implementation to benchmark (tilelang, triton, or all)", + ) args = parser.parse_args() @@ -941,8 +896,7 @@ if __name__ == "__main__": if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: # Adjust head_query to nearest valid value args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) - print( - f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") if args.suite: run_benchmark_suite(impl=args.impl) @@ -963,12 +917,14 @@ if __name__ == "__main__": scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (TileLang):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") @@ -986,11 +942,13 @@ if __name__ == "__main__": scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (Triton):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 1d1b5ea..56e98a9 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -7,6 +7,7 @@ import torch import triton import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -22,7 +23,8 @@ import tilelang tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tilelang_kernel_fwd( batch, heads, @@ -34,11 +36,10 @@ def tilelang_kernel_fwd( groups=1, selected_blocks=16, ): - from tilelang import language as T if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -67,12 +68,12 @@ def tilelang_kernel_fwd( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -93,7 +94,7 @@ def tilelang_kernel_fwd( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -103,12 +104,11 @@ def tilelang_kernel_fwd( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for k, j in T.Parallel(G, BS): - acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) @@ -138,7 +138,7 @@ def tilelang_kernel_fwd( acc_o[k, j] *= scores_scale[k] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): @@ -146,18 +146,20 @@ def tilelang_kernel_fwd( T.copy(acc_o, O_shared) T.copy( O_shared, - O_slc[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV], + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], ) for i in T.Parallel(G): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, LSE_slc[i_b, i_t, i_h * G:(i_h + 1) * G]) + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) return native_sparse_attention -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dkv( batch, heads, @@ -172,7 +174,7 @@ def tilelang_kernel_bwd_dkv( accum_dtype="float", ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv( @T.prim_func def flash_bwd_dkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, "int32"), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -238,31 +240,33 @@ def tilelang_kernel_bwd_dkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -273,7 +277,7 @@ def tilelang_kernel_bwd_dkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -282,7 +286,7 @@ def tilelang_kernel_bwd_dkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -296,7 +300,7 @@ def tilelang_kernel_bwd_dkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for i, j in T.Parallel(BS, G): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -305,8 +309,8 @@ def tilelang_kernel_bwd_dkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dkv @@ -321,9 +325,11 @@ def make_dq_layout(dQ): ) -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dqkv( batch, heads, @@ -338,7 +344,7 @@ def tilelang_kernel_bwd_dqkv( accum_dtype="float", ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -373,16 +379,16 @@ def tilelang_kernel_bwd_dqkv( @T.prim_func def flash_bwd_dqkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DQ: T.Tensor(dq_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, "int32"), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -406,31 +412,33 @@ def tilelang_kernel_bwd_dqkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -441,7 +449,7 @@ def tilelang_kernel_bwd_dqkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -450,7 +458,7 @@ def tilelang_kernel_bwd_dqkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -464,7 +472,7 @@ def tilelang_kernel_bwd_dqkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for _i, _j in T.Parallel(BS, G): dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale @@ -480,16 +488,18 @@ def tilelang_kernel_bwd_dqkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dqkv @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_preprocess( batch, heads, @@ -505,9 +515,9 @@ def tilelang_kernel_preprocess( @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -516,20 +526,22 @@ def tilelang_kernel_preprocess( delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * blk:(by + 1) * blk, bx]) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) return flash_bwd_prep @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_block_mask( batch, heads, @@ -551,9 +563,9 @@ def tilelang_kernel_block_mask( @T.prim_func def flash_bwd_block_mask( - BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore - BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore - BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore ): with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): i_t, i_b, i_hs = bx, by, bz @@ -603,9 +615,7 @@ def parallel_nsa_bwd( dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) - block_mask = tilelang_kernel_block_mask(B, H, T, S, - BS)(block_indices.to(torch.int32), - block_counts.to(torch.int32)).to(torch.bool) + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( batch=B, @@ -618,8 +628,7 @@ def parallel_nsa_bwd( selected_blocks=S, scale=scale, ) - fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, - block_mask.to(torch.int32)) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) dq = dq.sum(0) dk = dk.sum(0) @@ -628,7 +637,6 @@ def parallel_nsa_bwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -773,23 +781,21 @@ def parallel_nsa( Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), - (q, k, v, block_indices)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): block_counts = rearrange(block_counts, "b h t -> b t h") - assert (q.shape[2] % (k.shape[2] * 16) == 0), "Group size must be a multiple of 16 in NSA" + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: @@ -814,7 +820,7 @@ if __name__ == "__main__": for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 58f4355..38fc51a 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -16,7 +16,8 @@ tilelang.testing.set_random_seed(42) tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def native_sparse_attention( batch, heads, @@ -25,10 +26,10 @@ def native_sparse_attention( scale=None, block_size=64, # Tile size for attention computation groups=1, # Grouped query attention (GQA) groups - selected_blocks=16 # Number of blocks to select per attention head + selected_blocks=16, # Number of blocks to select per attention head ): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups # Modified shapes for inference (q has seq_len=1)a q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 @@ -53,12 +54,11 @@ def native_sparse_attention( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] - K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] - V: T.Tensor(kv_shape, dtype), # Same shape as K - BlockIndices: T.Tensor(block_indices_shape, - block_indices_dtype), # Selected block indices - Output: T.Tensor(q_shape, dtype), # Output attention tensor + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor ): with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): # Shared memory allocations for tile storage @@ -82,7 +82,7 @@ def native_sparse_attention( NS = S # Copy Q for the single position - T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 T.fill(acc_o, 0) T.fill(logsum, 0) @@ -93,16 +93,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset if i_s >= 0: # Skip invalid/padding blocks # Load current key block to shared memory - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) # Compute QK^T attention scores T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Online softmax with numerical stability # 1. Compute max for scaling @@ -122,15 +117,14 @@ def native_sparse_attention( T.copy(acc_s, acc_s_cast) # Accumulate attention-weighted values - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) # Final normalization and output for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] # Normalize by logsum T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G, - i_v * BV:(i_v + 1) * BV]) # Changed i_t to 0 + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 return native_sparse_attention @@ -149,21 +143,21 @@ def main(): selected_blocks=S, ) - Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') - DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") - block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN_Q): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") out = kernel(Q, K, V, block_indices.to(torch.int32)) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 0b71779..a8dd26b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -14,18 +14,11 @@ tilelang.testing.set_random_seed(0) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -52,11 +45,11 @@ def native_sparse_attention(batch, @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -77,7 +70,7 @@ def native_sparse_attention(batch, i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -87,21 +80,15 @@ def native_sparse_attention(batch, i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -121,13 +108,13 @@ def native_sparse_attention(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention @@ -148,20 +135,20 @@ def main(): ) print(kernel.get_kernel_source()) torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') - block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device='cuda') + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() block_indices = block_indices.sort(-1)[0] diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index d365e7a..af87db8 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ from tilelang import language as T import tilelang.testing import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -21,18 +22,11 @@ from einops import rearrange tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention_varlen(batch, - heads, - c_seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [c_seq_len, heads, dim] kv_shape = [c_seq_len, head_kv, dim] @@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch, @T.prim_func def native_sparse_attention_varlen( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), - Offsets: T.Tensor(offsets_shape, offsets_dtype), - TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), ): with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -100,7 +94,7 @@ def native_sparse_attention_varlen(batch, current_seq_len = eos - bos NS = BlockCounts[i_t, i_h] - T.copy(Q[bos + i_t, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -112,21 +106,15 @@ def native_sparse_attention_varlen(batch, # [BS, BK] # Lei: may have some padding issues # we should learn from mha varlen templates to handle this - T.copy(K[bos + i_s:bos + i_s + BS, i_h, :BK], K_shared) + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -146,13 +134,13 @@ def native_sparse_attention_varlen(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[bos + i_s:bos + i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, O_slc[bos + i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention_varlen @@ -190,17 +178,20 @@ def parallel_nsa_fwd( o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) kernel( - q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D), + q.view(C_SEQ_LEN, HQ, D), + k.view(C_SEQ_LEN, H, D), + v.view(C_SEQ_LEN, H, D), o_slc.view(C_SEQ_LEN, HQ, V), block_indices.to(torch.int32).view(C_SEQ_LEN, H, S), - block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32), - token_indices.to(torch.int32)) + block_counts.to(torch.int32).view(C_SEQ_LEN, H), + offsets.to(torch.int32), + token_indices.to(torch.int32), + ) return o_slc @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): ctx.dtype = q.dtype @@ -221,22 +212,25 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) return o_slc.to(q.dtype) -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, - scale, cu_seqlens) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: assert False, "Window size is not supported yet" else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -306,41 +298,57 @@ if __name__ == "__main__": N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[:N - 1]], - torch.tensor([C_SEQ_LEN], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) # seq-first required for inputs with variable lengths - perm_q = torch.randperm(C_SEQ_LEN, device='cuda') - perm_k = torch.randperm(C_SEQ_LEN, device='cuda') - perm_v = torch.randperm(C_SEQ_LEN, device='cuda') - q = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_q].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, HQ, - D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_k].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_v].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") for i in range(C_SEQ_LEN): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") ref = naive_nsa( q=q, @@ -351,7 +359,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -362,7 +371,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/example_triton_nsa_bwd.py b/examples/deepseek_nsa/example_triton_nsa_bwd.py index e912794..af05bfa 100644 --- a/examples/deepseek_nsa/example_triton_nsa_bwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -134,7 +154,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None @@ -199,37 +220,56 @@ def parallel_nsa_fwd( return o_slc, lse_slc, o_swa, lse_swa -@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk, - dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr, - H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, - V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + - 1).to(tl.int32) + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), - (1, 0)) - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), - (i_s * BS, 0), (BS, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) # [BS, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, for i in range(i_s * BS, T): b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) if b_m_slc: - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, if WS > 0: o_s = i_s * BS + tl.arange(0, BS) if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), - (G, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics( - {'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)}) +@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)}) @triton.jit -def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr, - H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_s = i_hs // S, i_hs % S @@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons b_m = b_i * BS <= i_t if b_i < NS and b_i >= 0: - tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, - b_m.to(block_mask.dtype.element_ty)) + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq, - scale, block_indices, block_counts, offsets, token_indices, T, - B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -593,14 +680,8 @@ def parallel_nsa_block_mask( block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) parallel_nsa_kernel_mask[(T, B, H * S)]( - block_indices=block_indices, - block_counts=block_counts, - block_mask=block_mask, - T=T, - H=H, - S=S, - BS=BS, - NS=NS) + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) return block_mask @@ -676,7 +757,8 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dq = dq.sum(0) if offsets is not None: @@ -719,14 +801,14 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dk = dk.sum(0) return dq, dk, dv @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -749,7 +831,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -781,22 +864,25 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd.py b/examples/deepseek_nsa/example_triton_nsa_fwd.py index 2c74001..c9ab28d 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -177,7 +197,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -200,7 +219,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -212,18 +232,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py index 9ccbff6..cb4eb6d 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,27 +18,49 @@ from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -196,7 +216,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -219,7 +238,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -231,18 +251,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -312,38 +332,35 @@ if __name__ == "__main__": N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]], - torch.tensor([T], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) # offsets.shape is [N+1] # seq-first required for inputs with variable lengths - perm_q = torch.randperm(T, device='cuda') - perm_k = torch.randperm(T, device='cuda') - perm_v = torch.randperm(T, device='cuda') - q = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda') + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") for i in range(T): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") ref = naive_nsa( q=q, @@ -354,7 +371,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -365,7 +383,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/reference.py b/examples/deepseek_nsa/reference.py index 958d0c1..5808310 100644 --- a/examples/deepseek_nsa/reference.py +++ b/examples/deepseek_nsa/reference.py @@ -6,18 +6,20 @@ from typing import Union from einops import rearrange, repeat -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) @@ -187,7 +178,7 @@ def naive_nsa_simple( o (torch.Tensor): Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -197,8 +188,8 @@ def naive_nsa_simple( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(v) @@ -228,10 +219,10 @@ def naive_nsa_simple( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) @@ -265,7 +256,7 @@ def naive_nsa_simple_inference( o (torch.Tensor): Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -275,8 +266,8 @@ def naive_nsa_simple_inference( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(q) @@ -306,9 +297,9 @@ def naive_nsa_simple_inference( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index dd94064..305e2af 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai if should_raise: assert False if not torch.isclose( - a.masked_fill(a_finite, 0), - b.masked_fill(b_finite, 0), - rtol=0, - atol=0, - equal_nan=True, + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, ).all(): display_error_message(f"{tensor_name} Error: nonfinite value mismatch") if should_raise: @@ -55,13 +55,10 @@ def get_configs(): threads=[128, 256], block_Q=[1, 2, 4], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] class SupplyProg: - def __init__(self): self.tensors_dict = {} @@ -88,7 +85,8 @@ supply_prog = SupplyProg() @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - },) + }, +) def mqa_attn_return_logits( heads, index_dim, @@ -113,16 +111,15 @@ def mqa_attn_return_logits( @T.prim_func def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: - index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) index_k_shared = T.alloc_shared([block_N, index_dim], dtype) index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) @@ -140,17 +137,14 @@ def mqa_attn_return_logits( cu_k_e_max[0] = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], - seq_len_kv)) + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], - seq_len_kv)) + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) - for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) @@ -164,15 +158,14 @@ def mqa_attn_return_logits( ) for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i]) + Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] return mqa_attn_return_logits_kernel @@ -190,9 +183,9 @@ def clean_logits_( @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): with T.Kernel(seq_len, threads=threads) as bx: tx = T.thread_binding(0, threads, thread="threadIdx.x") @@ -210,13 +203,7 @@ def clean_logits_( return clean_logits_kernel -def mqa_attn_return_logits_interface(q, - kv, - kv_scales, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - clean_logits=True): +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] @@ -238,20 +225,19 @@ def mqa_attn_return_logits_interface(q, return logits -def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): k = kv q = q.float() k = k.float() seq_len_kv = kv.shape[0] - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost @@ -265,32 +251,22 @@ def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): weights = torch.randn(S, H, device="cuda", dtype=torch.float32) p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = validate_tensor_match( - logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) print(f"diff: {diff}") from tilelang.profiler import do_bench def logits_fn(): - return mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: logits_fn() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 4ff3b81..1266e70 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -22,9 +22,9 @@ def preprocess( @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: T.Tensor(shape, dtype), - Delta: T.Tensor([B, S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -33,16 +33,12 @@ def preprocess( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy( - O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - o) - T.copy( - dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) return preprocess_kernel @@ -65,13 +61,13 @@ def postprocess( @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): T.copy( - dKV[bz, bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :], + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -83,7 +79,8 @@ def postprocess( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, - }) + }, +) def bwd( B, S, @@ -102,14 +99,14 @@ def bwd( dtype="bfloat16", accum_dtype="float", ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" assert dtype == "bfloat16" assert accum_dtype == "float" assert indices_dtype == "int32" if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) H_kv = H // kv_group @@ -132,14 +129,14 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): Q_shared = T.alloc_shared([padded_H, D], dtype) @@ -165,17 +162,19 @@ def bwd( max_kv_i = s_i - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): @@ -191,62 +190,31 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D): KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, - D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - - Lse[by, s_i, bz * padded_H + h_i]) + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -255,41 +223,32 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -322,6 +281,7 @@ def sparse_mla_bwd(q, def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -331,30 +291,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) @@ -365,13 +317,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -379,20 +333,9 @@ def test_sparse_mla_bwd(B=1, ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index e65b890..3b963c7 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -25,15 +25,12 @@ def sparse_mla_fwd( num_stages=2, threads=256, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -55,9 +52,9 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -73,18 +70,17 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -118,16 +114,13 @@ def sparse_mla_fwd( T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - D + d_i] + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -176,15 +169,7 @@ def sparse_mla_fwd( return main -def sparse_mla_fwd_interface(q, - kv, - indices, - sm_scale=None, - return_p_sum: bool = False, - d_v=512, - block_I=64, - num_stages=2, - threads=256): +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -201,16 +186,8 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices) return out, lse @@ -230,14 +207,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -252,19 +229,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -274,10 +253,9 @@ def test_sparse_mla_fwd(B=1, for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -286,8 +264,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -315,4 +292,5 @@ if __name__ == "__main__": check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 1621d85..972160c 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -9,10 +9,16 @@ import argparse @tilelang.jit( out_idx=[-2, -1], compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) def sparse_mla_fwd( @@ -32,14 +38,12 @@ def sparse_mla_fwd( num_stages=0, threads=384, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, 'non-casual is not supported' - assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -57,15 +61,17 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) - assert NI % 2 == 0, 'NI should be a multiple of 2' + assert NI % 2 == 0, "NI should be a multiple of 2" D = dim D_tail = tail_dim KV_stride = kv_stride if head_kv > 64: - assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" REPLICATE_H = head_kv // 64 else: REPLICATE_H = 1 @@ -74,18 +80,14 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, - batch, - kv_group, - threads=threads) as (bx, by, bz): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) @@ -122,8 +124,7 @@ def sparse_mla_fwd( bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( - bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) q_i = q_start_index_s[0] + s_i max_kv_i = (q_i + 1 - KV_stride) // KV_stride @@ -132,26 +133,24 @@ def sparse_mla_fwd( tx = T.get_thread_binding() - T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) @@ -187,8 +186,7 @@ def sparse_mla_fwd( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) @@ -227,7 +225,7 @@ def sparse_mla_fwd( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -257,7 +255,7 @@ def sparse_mla_fwd( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) elif tx >= 256: # producer T.set_max_nreg(80, 0) @@ -265,70 +263,58 @@ def sparse_mla_fwd( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main -def sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride, - sm_scale=None, - is_casual=True, - return_kernel=False, - print_kernel=False): +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape _, seq_len_kv, kv_group, _ = kv.shape - assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" dim = 512 assert kv.shape[-1] == dim_plus_tail_dim @@ -338,29 +324,23 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) if q_start_index_s != 0: - assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) CP0 = q_start_index_s == 0 - kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, - kv_group, sm_scale, is_casual, CP0) + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) if print_kernel: print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, - torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) if return_kernel: return kernel if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 + out[:, : kv_stride - 1, :, :] = 0 return out, lse -def ref_sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride=4, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): q = q.float() kv = kv.float() indices = indices.transpose(1, 2) @@ -369,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q, if q_start_index_s is None: q_start_index_s = sk * kv_stride - sq - assert kv.shape[-1] == 576, 'you should assign dim otherwise' + assert kv.shape[-1] == 576, "you should assign dim otherwise" dim = 512 k = kv v = kv[..., :dim] @@ -378,15 +358,14 @@ def ref_sparse_mla_fwd_interface(q, num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - q_start_index_s, sq + q_start_index_s, dtype=torch.int32, - device="cuda").view(-1, 1) >= torch.arange( - kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :kv_stride - 1, 0] = True + mask[:, :, : kv_stride - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -401,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q, return o.to(torch.bfloat16) -def test_sparse_mla_fwd_pipelined(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - q_start_s_index=1024, - check_correctness=True): +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): KV_stride = 1 torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") q.clamp_(-10, 10) kv.clamp_(-10, 10) - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - kernel = sparse_mla_fwd_interface( - q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) def fn(): out, lse = kernel(q, kv, indices, q_start_s_index_t) if q_start_s_index == 0 and KV_stride > 1: - out[:, :KV_stride - 1, :, :] = 0 + out[:, : KV_stride - 1, :, :] = 0 return out, lse tl_out, tl_lse = fn() @@ -446,14 +416,15 @@ def test_sparse_mla_fwd_pipelined(B=1, torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) from tilelang.profiler import do_bench + ms = do_bench( fn, rep=10, warmup=10, ) print(f"Average time: {ms:.3f} ms") - print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) if __name__ == "__main__": @@ -464,5 +435,4 @@ if __name__ == "__main__": B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 else: B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - test_sparse_mla_fwd_pipelined( - B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index 2dd2704..6b7e879 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -21,23 +21,20 @@ def test_example_fp8_lighting_indexer(): @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd(): # small shapes for testing - sparse_mla_fwd.test_sparse_mla_fwd( - S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd_pipelined(): # small shapes for testing - sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined( - S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_bwd(): - sparse_mla_bwd.test_sparse_mla_bwd( - S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) if __name__ == "__main__": diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 4a4b432..cf87f52 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -127,9 +127,9 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): l_num_input = s_num_input[r_idx] for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() # cumsum @@ -156,23 +156,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: if round == 3: - l_out_pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos if l_out_pos < topk: index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] else: pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] return tl_topk_kernel @@ -186,7 +183,6 @@ def tl_topk(input, starts, ends, topk): def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): - batch = 64 seq_len = 32 * 1024 topk = 2048 @@ -212,8 +208,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): set_ref = set(ref_np) set_trt = set(trt_np) intersection = set_ref & set_trt - print("selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) # Performance test with CUDA events diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py index 2ea34b1..d7252e1 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -23,8 +23,7 @@ def _is_equal(a, b): if isinstance(a, torch.Tensor): return a is b # Whitelist of types that are safe to compare by value for caching. - if isinstance(a, (int, float, str, bool, type(None))) and isinstance( - b, (int, float, str, bool, type(None))): + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): return a == b # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. return False @@ -58,9 +57,11 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor] if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): # For Tensors, check for object identity. For other types, check for equality. # Python caches small integers, so `is` works for them but not for large integers like 4096. - if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ - set(kwargs.keys()) == set(last_kwargs.keys()) and \ - all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): + if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): @tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), - len(cu_seqlens_qs), - dtype=torch.int32, - device=cu_seqlens_qs.device) +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i return seq_idx_for_q @tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: cu_seqlen_ks_for_each_q = torch.gather( - input=torch.cat([ - cu_seqlens_ks, - torch.full((1,), - torch.iinfo(torch.int32).max, - dtype=torch.int32, - device=cu_seqlens_qs.device) - ]), + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) return cu_seqlen_ks_for_each_q.int() @tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, - q_start_idxs: torch.LongTensor, seq_len: int, - kv_stride: int) -> torch.IntTensor: +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: cu_seqlen_ke_for_each_q = torch.gather( - input=torch.cat( - [cu_seqlens_ke, - torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), - dtype=torch.int32, - device=cu_seqlens_qs.device) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( - q_start_idxs[i], - q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], - dtype=torch.int32, - device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) return cu_seqlen_ke_for_each_q.int() @tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, - cu_seqlens_k: torch.LongTensor = None, - offs_q: torch.LongTensor = None, - *, - seq_len: int, - kv_stride: int = 1, - cp_rank: int = 0, - cp_size: int = 1, - balanced_cp=False): - ''' +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ seq_len: seq len per cp rank balanced cp slice assignment: 0 1 2 3 3 2 1 0 - ''' + """ n_seq = len(cu_seqlens_q) - 1 assert n_seq > 0 assert cu_seqlens_q.shape == (n_seq + 1,) @@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, def f(x: torch.Tensor): chunks = x.chunk(cp_size * 2) - return torch.cat([ - chunks[cp_rank], - chunks[cp_size - cp_rank - 1], - ]) + return torch.cat( + [ + chunks[cp_rank], + chunks[cp_size - cp_rank - 1], + ] + ) ks = f(ks) ke = f(ke) @@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], - use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen, ) - ks = torch.cat([ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ]) - ke = torch.cat([ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ]) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) assert len(ks) == len(ke) == per_cp_seqlen return ks, ke @@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 @@ -316,11 +315,8 @@ if __name__ == "__main__": cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat( - [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat( - [cu_seqlens_cumsum, - torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) from tilelang.profiler import do_bench diff --git a/examples/dequantize_gemm/dequantize_utils.py b/examples/dequantize_gemm/dequantize_utils.py index b14c0ae..90a6265 100644 --- a/examples/dequantize_gemm/dequantize_utils.py +++ b/examples/dequantize_gemm/dequantize_utils.py @@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor): res0 = val_concat_expanded & mask res1 = (val_concat_expanded << 3) & mask res2 = (val_concat_expanded << 6) & mask - res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ( - (val_concat_expanded >> 7) & mask3) + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) # Select the correct result based on position - bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, - torch.where(pos == 2, res2, res3))) + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) # Convert to uint16 for .view(torch.bfloat16) bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) @@ -110,7 +108,7 @@ def print_bit(name, val): val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. """ val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) @@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = (1. - sim).item() - print(f'{diff=}') + diff = (1.0 - sim).item() + print(f"{diff=}") if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff=}') + print_red_warning(f"{name} Error: {diff=}") if raise_assert: raise AssertionError diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index e30845b..ba3e0b4 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -24,6 +24,7 @@ def get_configs(): the parameter name to its chosen value. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -32,63 +33,62 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - fast_dequant=True, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. + - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. """ - Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. - - This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: - - A: dense input of shape (M, K) with dtype `in_dtype`. - - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. - - C: output of shape (M, N) with dtype `out_dtype`. - - The generated kernel supports two dequantization paths: - - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. - - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. - - Important behavior and requirements: - - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. - - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. - - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. - - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. - - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. - - Parameters that alter kernel layout/behavior (brief): - - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. - - num_stages: number of software pipeline stages for the K-loop. - - threads: number of threads used per kernel block. - - split: extra K-splitting factor; K must be divisible by block_K * split. - - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. - - Returns: - A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. - """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -189,8 +189,7 @@ def matmul(M, # Finally, store the dequantized data to shared memory. for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -215,30 +214,29 @@ def matmul(M, assert in_dtype in ["fp4"] assert out_dtype in ["bfloat16"] - def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, - scale: tir.PrimExpr, dtype: str): + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. - - This helper extracts the 4-bit field located at the bit position `pos` within the - byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an - exponent `scale` offset to align it with bfloat16 exponent bias, clamps the - resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. - - Parameters: - nbit (int): Number of bits in the packed element; must be 4. - val (tir.PrimExpr): A uint8 value containing packed FP4 elements. - pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. - scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. - dtype (str): Target dtype string; must be "bfloat16". - - Returns: - tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. - - Notes: - - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - bit fields and clamps the computed exponent to fit into 8 bits. + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be "bfloat16". + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. """ assert nbit == 4 assert dtype == "bfloat16" @@ -254,8 +252,9 @@ def matmul(M, e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") val_bf16 = tir.reinterpret( - "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @T.macro @@ -292,32 +291,32 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Kernel entry for the tiled, pipelined matmul used by the generated prim_func. - - This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: - - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. - - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. - - Pipelines over K in chunks of `block_K` for `num_stages` stages: - - Loads A and packed B tiles into shared memory. - - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. - - Performs a GEMM accumulating into C_local with B transposed. - - Stores the accumulated block from C_local back to the global output C via C_shared. - - Parameters: - - A: input tile of shape (M, K) with dtype `in_dtype`. - - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). - - C: output tensor of shape (M, N) with dtype `out_dtype`. - - Side effects: - - Writes the computed output block into the global tensor `C`. - - Uses and updates shared memory buffers and per-thread accumulators. - - No value is returned. + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. + - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -327,9 +326,11 @@ def matmul(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) T.clear(C_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -344,7 +345,7 @@ def matmul(M, T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -409,8 +410,7 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): """ total_flops = 2 * m * n * k if tune: - kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) + kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) else: kernel = matmul( m, @@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): block_K=128, num_stages=2, threads=256, - split=1) + split=1, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) if fast_dequant: profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index ac1417a..1091306 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -7,29 +7,28 @@ import torch from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be "bfloat16"). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + val_bf16 = tir.reinterpret( + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,67 +74,71 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., "bfloat16"). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -252,8 +258,7 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -301,33 +306,32 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale[ - bx * block_N + i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, - ) * T.shift_left( - 1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) return simple_dequant_bf16_fp4 @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -337,23 +341,26 @@ def matmul(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() if with_bias: - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - Bias_shared) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) T.copy(Bias_shared, C_local) else: T.clear(C_local) @@ -368,7 +375,7 @@ def matmul(M, T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -389,7 +396,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): """ dtypeC = "bfloat16" B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -412,7 +419,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): """ dtypeC = "bfloat16" B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -436,7 +443,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): """ dtypeC = "bfloat16" B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -464,7 +471,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): """ dtypeC = "bfloat16" B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -491,16 +498,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, @@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 7dad795..12395df 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -7,29 +7,28 @@ import torch from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be "bfloat16"). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + val_bf16 = tir.reinterpret( + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,67 +74,71 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., "bfloat16"). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -252,8 +258,7 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -301,8 +306,8 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -311,22 +316,22 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -339,16 +344,20 @@ def matmul(M, # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() @@ -357,26 +366,24 @@ def matmul(M, # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], # Bias_shared) # T.copy(Bias_shared, C_local) - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - C_local) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local) else: T.clear(C_local) # Use 1D TMA to load Scale - T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared) for k in T.Pipelined(K // block_K, num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, @@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 727d6d3..c2b972a 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -24,6 +24,7 @@ def matmul( num_bits=4, ): from tilelang.quantize import _tir_packed_to_unsigned_convert + num_elems_per_byte = 8 // num_bits storage_dtype = "int8" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) @@ -39,9 +40,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -58,21 +59,19 @@ def matmul( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = i * threads * local_size_compressed + tx * local_size_compressed + v vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] for v in T.serial(0, local_size): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v vi = index // block_K @@ -121,9 +120,7 @@ def run_gemm( def ref_program(A, qB): import torch - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -146,9 +143,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ): from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitterWithLadderTransform,) + TensorCoreIntrinEmitterWithLadderTransform, + ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + assert in_dtype in [ "float16", "int8", @@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( pad_factor = 8 A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, - micro_size_k // num_elems_per_byte) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, @@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( chunk=chunk, reduce_k=reduce_k, transform_kind_b=transform_b, - num_elems_per_byte=num_elems_per_byte) + num_elems_per_byte=num_elems_per_byte, + ) vec_load_qb = 16 if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: @@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, - prelude=decode_i4_to_f16) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( thread_binding = T.get_thread_binding(0) rk = T.get_thread_binding(1) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) T.use_swizzle(panel_size=10) T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, (block_K // reduce_k)): vk = rk * (block_K // reduce_k) + k A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load - for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // - (threads * vec_load_qb)): + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): for v in T.vectorized(0, vec_load_qb): t = thread_binding idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v vkk = idx % (micro_size_k // num_elems_per_byte) vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y - vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( - block_K // micro_size_k) - vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // - (block_K // micro_size_k)) % ( - block_N // micro_size_y) - B_shared[vj, vk, vjj, - vkk] = B[bx * (block_N // micro_size_y) + vj, - ko * (block_K // micro_size_k) + vk, vjj, vkk] + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( for j in T.serial(warp_cols): local_size_b = mma_emitter.local_size_b - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) mma_emitter.mma(A_local, B_dequantize_local, C_local) @@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( reduced_accum_res[0], rk, dtype="handle", - )) + ) + ) if rk == 0: C_local[n] = reduced_accum_res[0] @@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( for i, j in T.Parallel(block_M, (block_N // reduce_k)): vj = rk * (block_N // reduce_k) + j - C[by * block_M + i, - bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, - i % micro_size_x, vj % micro_size_y] + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] return main @@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct transform_b, ): import bitblas - matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) kernel = tilelang.compile(matmul, out_idx=[2]) src_code = kernel.get_kernel_source() @@ -371,8 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct storage_dtype = "int8" A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( @@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Ensure that the latency is not None assert latency is not None - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -429,8 +423,7 @@ def test_run_dequantize_gemm(): @tilelang.testing.requires_package("bitblas") def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): - assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( - 256, 1024, 512, "float16", "float16", "float16", 3) + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3) def main(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c5588d5..352637d 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -21,18 +21,17 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: e_f16 = e_f4 + tir.const(14, "uint16") m_f4 = f4 & tir.const(1, "uint16") m_f16 = m_f4 - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") - | m_f16 << tir.const(9, "uint16")).astype("uint16")) + val_f16 = tir.reinterpret( + "float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16") + ) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) return val_f16 def torch_convert(tensor): - def print_bit(name, val): val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) def _convert(val, pos): @@ -68,8 +67,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -118,19 +117,11 @@ def get_configs(): splits = [1] _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'threads': c[4], - 'split': c[5] - } for c in _configs] + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] return configs def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits @@ -145,17 +136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @T.prim_func def main_split( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - SplitC = T.alloc_buffer([ - split, (N + block_N - 1) // block_N * block_N, - (M + block_M - 1) // block_M * block_M - ], out_dtype) - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, - threads=threads) as (bx, by, bz): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): @@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ) T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) - T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): acc = T.alloc_fragment((block_N, block_M), out_dtype) T.clear(acc) @@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) if split == 1: return main @@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[2]) - def kernel(block_M=None, - block_N=None, - block_K=None, - num_stages=None, - threads=None, - split=None): + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func return kernel() @@ -269,10 +251,10 @@ def ref_program(A, qB): def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul( - m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + if not tune: + kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=256, help='M') - parser.add_argument('--n', type=int, default=256, help='N') - parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() M, N, K = args.m, args.n, args.k main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 52ee821..3ff7267 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -42,8 +42,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): def torch_convert(tensor): - def _convert(val, pos): assert val.dtype == torch.uint8 val = val.view(torch.int8) mask = (1 << 4) - 1 - i4_shifted = ((val >> (pos * 4)) & mask) - i4 = ((i4_shifted << 4) >> 4) + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 return i4.view(torch.int8) @@ -94,7 +93,6 @@ def ref_program(A, qB): def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads): num_elems_per_byte = 8 // num_bits @@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) return main @@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune def main(m=128, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul_int8xint4( - m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( - block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) + if not tune: + kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) print("All checks pass.") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index d3e90ec..3f12146 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -4,7 +4,8 @@ from typing import Optional, Callable, Any import torch from tilelang import DataType from tilelang.quantize import ( - _tir_packed_int_to_int_convert,) + _tir_packed_int_to_int_convert, +) @tilelang.jit @@ -26,11 +27,10 @@ def dequantize_gemv( group_size: int = -1, with_scaling: bool = False, ) -> Callable[..., Any]: - assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" @@ -81,12 +81,12 @@ def dequantize_gemv( C: T.Tensor[C_shape, out_dtype], ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -107,8 +107,7 @@ def dequantize_gemv( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] if fast_decoding: @@ -120,10 +119,9 @@ def dequantize_gemv( ) else: for ki in T.serial(micro_size_k): - B_dequantize_local[ki] = _tir_packed_int_to_int_convert( - storage_type, - storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], - ki % num_elems_per_byte, in_dtype) + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype + ) if use_dp4a: for ki in T.serial(micro_size_k // dp4a_size): @@ -137,9 +135,9 @@ def dequantize_gemv( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -149,7 +147,8 @@ def dequantize_gemv( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -174,26 +173,39 @@ def main() -> None: group_size = -1 with_scaling = False - kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, - source_format, n_partition, reduce_thread, fast_decoding, trans_A, - trans_B, group_size, with_scaling) + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) num_elems_per_byte = storage_nbit // num_bits A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() if fast_decoding: from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) kernel(A, qB, C) # int4 reference - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for j in range(B.shape[1]): B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index c4cf5fb..098f814 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -25,6 +25,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[128], block_N=[64, 128, 256], @@ -33,33 +34,33 @@ def get_configs(): threads=[128, 256, 512], split=[1], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - topk, - E, - padding_M, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=128, - block_N=256, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. @@ -115,11 +116,12 @@ def matmul(M, Block_QK = block_K // num_elems_per_byte A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = (block_N) + Bias_shared_shape = block_N B_dequantize_shared_shape = (block_N, block_K) assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -221,19 +223,16 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - assert in_dtype in ["fp4"] assert out_dtype in ["bfloat16"] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) @@ -244,8 +243,8 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -254,19 +253,17 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((E, N, QK), storage_dtype), - Scale: T.Tensor((E, N, K // scale_size), storage_dtype), - Bias: T.Tensor((E, N), out_dtype), - # Add fusedmoe tensors - topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: T.Tensor((padding_M), "int32"), - expert_ids: T.Tensor((padding_M // block_M), "int32"), - C: T.Tensor((M, topk, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), "int32"), + expert_ids: T.Tensor((padding_M // block_M), "int32"), + C: T.Tensor((M, topk, N), out_dtype), ): - - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) @@ -280,17 +277,19 @@ def matmul(M, # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) T.use_swizzle(10) if threads == 512: T.disable_warp_group_reg_alloc() - T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared) + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) expert_id[0] = expert_ids[by] # Get the topk weights of each token in the current block @@ -300,11 +299,11 @@ def matmul(M, # Get bias and scale based on the expert id if with_bias: - T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared) + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) else: T.clear(Bias_shared) - T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) for i, j in T.Parallel(block_M, block_N): C_local[i, j] = Bias_shared[j] @@ -317,14 +316,13 @@ def matmul(M, base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_K] != -1: for copy_j in T.vectorized(16): - A_shared[base // block_K, base % block_K + - copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, - k * block_K + base % block_K + copy_j] + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) @@ -338,10 +336,11 @@ def matmul(M, base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_N] != -1: for copy_j in T.vectorized(16): - C[sorted_token_ids_shared[base // block_N] // topk, - sorted_token_ids_shared[base // block_N] % topk, bx * block_N + - base % block_N + copy_j] = C_shared[base // block_N, - base % block_N + copy_j] + C[ + sorted_token_ids_shared[base // block_N] // topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] return main @@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc assert scale_size == 32 # MXFP4 # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda') + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") # Iterate over sorted_token_ids for idx in range(len(sorted_token_ids)): # padding_M @@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc # Dequantize the expert weights B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 2**( - Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to( - torch.bfloat16)) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) # Compute the output for this token-expert pair # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to( - torch.bfloat16)) + Bias[expert_id] + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] output = output.to(torch.__getattribute__(dtypeC)) # Apply the topk weight @@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc def get_data(m, n, k, qk, scale_size, topk, E, block_M): - A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - qB = torch.randint( - 0, 256, (E, n, qk), dtype=torch.uint8, - device='cuda') # Quantized weight tensor for E experts. - Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda') - Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - - weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) # topk_weights: Router weights for the top-k experts for each token. # Shape: (m, topk) # tokens_experts: A flattened tensor of expert assignments for each token. @@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt if pad_len > 0: # -1 for padding (`M` instead in vLLM moe_align_block_size()) - group_token_ids = torch.cat([ - group_token_ids, - torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda') - ]) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) padded_token_ids.append(group_token_ids) expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) start = end @@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): # sorted_token_ids: The final flattened and padded tensor of token indices. sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. - expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,) + expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M -def main(m=256, - n=256, - k=256, - scale_size=32, - topk=4, - E=32, - fast_dequant=True, - with_bias=False, - tune=False): +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): # Tunable parameters block_M, block_N, block_K = 128, 256, 128 # noqa: F841 num_stages = 1 # noqa: F841 @@ -456,8 +439,7 @@ def main(m=256, num_bits = 4 num_elems_per_byte = 8 // num_bits qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( - m, n, k, qk, scale_size, topk, E, block_M) + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) if tune: with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): @@ -510,14 +492,11 @@ def main(m=256, expert_ids, ) - print('Tilelang kernel run finished.') + print("Tilelang kernel run finished.") - ref_output = ref_moe( - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, - block_M=block_M) # Maybe a little bit slow... + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... - latency = tilelang.profiler.do_bench( - lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) print("Tilelang: {:.2f} ms".format(latency)) print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) @@ -525,32 +504,19 @@ def main(m=256, max_val = diff.max() max_idx = diff.argmax() print(f"max abs diff: {max_val} at index: {max_idx}") - assert_similar( - output, ref_output, name="output", - eps=2e-5) # We care about the similarity rather than abs. difference + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference print("All checks pass. ✅") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm parser.add_argument("--N", type=int, default=5760, help="N") parser.add_argument("--K", type=int, default=2944, help="K") parser.add_argument("--scale_size", type=int, default=32, help="scale size") - parser.add_argument( - "--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token parser.add_argument("--E", type=int, default=32, help="E") # number of experts parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main( - args.M, - args.N, - args.K, - args.scale_size, - topk=args.topk, - E=args.E, - fast_dequant=True, - with_bias=True, - tune=args.tune) + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dsa_sparse_finetune/dsa.py b/examples/dsa_sparse_finetune/dsa.py index 1ca2824..9fae8e5 100644 --- a/examples/dsa_sparse_finetune/dsa.py +++ b/examples/dsa_sparse_finetune/dsa.py @@ -11,7 +11,6 @@ from utils import get_abs_err, get_err_ratio class RegsiterLossFunction(torch.autograd.Function): - @staticmethod def forward(ctx, x, loss): ctx.save_for_backward(loss) @@ -38,49 +37,43 @@ def ref_deepseek_sparse_attention_innner( index_sm_scale: Optional[float] = None, ): dtype = q.dtype - q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), - (q, kv, index_q, index_k, weights)) + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) - index_sm_scale = index_q.shape[-1]**-0.5 + index_sm_scale = index_q.shape[-1] ** -0.5 b, s = index_q.shape[:2] # tl_topk_indices = tl_topk_indices.to(torch.int64) # tl_topk_indices[tl_topk_indices == -1] = s casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2') + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") index_logits = F.relu(index_logits) - index_logits = (index_logits * weights.unsqueeze(-1)).sum( - dim=-2, dtype=torch.float32) * index_sm_scale - index_logits = torch.where(casual_mask, index_logits, float('-inf')) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices - topk_logits = torch.gather( - F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices) + topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices) topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) index_topk_score = topk_score if sm_scale is None: - sm_scale = kv.shape[-1]**-0.5 + sm_scale = kv.shape[-1] ** -0.5 h = q.shape[-2] - index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\ - .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1] - mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h) + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) k, v = kv, kv[..., :dim_v] - logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale - logits = torch.where(mask, logits, float('-inf')) + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) - o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d') + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") attn_score = attn_score.sum(dim=-2) # [b, s1, s2] attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) - loss = F.kl_div( - index_topk_score.clip(-100, 0), - attn_topk_score.detach().log().clip(-100, 0), - log_target=True, - reduction="sum") + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") o = register_loss(o, loss) return o.to(dtype), topk_indices @@ -101,11 +94,11 @@ def ref_deepseek_sparse_attention( all_o, all_topk_indices = [], [] for i in range(offsets.shape[0] - 1): o, topk_indices = ref_deepseek_sparse_attention_innner( - q[None, offsets[i]:offsets[i + 1]], - kv[None, offsets[i]:offsets[i + 1]], - index_q[None, offsets[i]:offsets[i + 1]], - index_k[None, offsets[i]:offsets[i + 1]], - weights[None, offsets[i]:offsets[i + 1]], + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], topk, dim_v, sm_scale, @@ -119,7 +112,6 @@ def ref_deepseek_sparse_attention( class DSAFunction(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -134,12 +126,9 @@ class DSAFunction(torch.autograd.Function): sm_scale: Optional[float] = None, ): # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) - topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, - topk, offsets) - o, lse = sparse_mla_fwd_interface( - q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) - ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, - offsets) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) ctx.topk = topk ctx.dim_v = dim_v ctx.sm_scale = sm_scale @@ -153,19 +142,10 @@ class DSAFunction(torch.autograd.Function): ): q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors attn_score = sparse_mla_topk_reducesum_interface( - q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, - dim_v=ctx.dim_v).squeeze(-2) - dq, dkv = sparse_mla_bwd( - q, - kv.unsqueeze(-2), - o, - do, - topk_indices.unsqueeze(-2), - lse, - offsets, - sm_scale=ctx.sm_scale) - dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, - index_score, topk_indices, offsets) + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None @@ -209,8 +189,7 @@ def test_kernel( index_k_grad, index_k.grad = index_k.grad, None weights_grad, weights.grad = weights.grad, None - ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, - offsets, topk, D) + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) ref_o.backward(do) ref_q_grad, q.grad = q.grad, None ref_kv_grad, kv.grad = kv.grad, None @@ -219,28 +198,20 @@ def test_kernel( ref_weights_grad, weights.grad = weights.grad, None print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") - print( - f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}" - ) - print( - f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}" - ) + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") print( f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" ) - print( - f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}" - ) - print( - f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}" - ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") intersections = [] for j in range(S): ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() trt_np = topk_indices[j].cpu().to(torch.int32).numpy() - mask = (trt_np != -1) + mask = trt_np != -1 set_ref = set(ref_np[mask]) set_trt = set(trt_np[mask]) diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py index 92ce687..5e48004 100644 --- a/examples/dsa_sparse_finetune/index.py +++ b/examples/dsa_sparse_finetune/index.py @@ -5,7 +5,9 @@ import functools from typing import Callable, Any -def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]: +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: """ A decorator that caches the most recent result of a function with tensor inputs. @@ -29,10 +31,12 @@ def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal last_args, last_kwargs, last_result - if (last_args is not None and last_kwargs is not None) and \ - (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \ - all(a is b for a, b in zip(args, last_args, strict=False)) and \ - all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -56,16 +60,15 @@ def prepare_cu_seqlens_from_lens( @tensor_cache -def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor: +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: return torch.diff(cu_seqlens) @tensor_cache def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: - return torch.cat([ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in prepare_lens(cu_seqlens).unbind() - ]) + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) @tensor_cache diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py index 5430c1c..5d8132d 100644 --- a/examples/dsa_sparse_finetune/indexer_bwd.py +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -49,17 +49,17 @@ def tl_indexer_bwd_impl( @T.prim_func def tl_indexer_bwd_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), - Weights: T.Tensor(weights_shape, dtype), - IndexK: T.Tensor(index_k_shape, dtype), - dIndexQ: T.Tensor(index_q_shape, dtype), - dWeights: T.Tensor(weights_shape, dtype), - dIndexK: T.Tensor(index_k_shape, dtype), - AttnScore: T.Tensor(shape_p, FP32), - IndexScore: T.Tensor(shape_p, FP32), - TopkIndices: T.Tensor(topk_indices_shape, INT32), - Offsets: T.Tensor(offsets_shape, INT32), - TokenIndices: T.Tensor(token_indices_shape, INT32), + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), ): with T.Kernel(seq_len, threads=num_threads) as (bx): i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] @@ -81,7 +81,6 @@ def tl_indexer_bwd_impl( index_q_shared[i, j] = index_q_shared[i, j] * sm_scale for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): - i_st = bi_i * block_I i_ed = (bi_i + 1) * block_I @@ -91,8 +90,7 @@ def tl_indexer_bwd_impl( index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) for i, j in T.Parallel(block_I, dim): pos = indices_shared[i] - index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), - IndexK[bos + pos, j], 0) + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) @@ -115,8 +113,7 @@ def tl_indexer_bwd_impl( # dw d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) for i, j in T.Parallel(block_I, heads): - d_weights_i[i, - j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) @@ -129,8 +126,7 @@ def tl_indexer_bwd_impl( d_relu = 1.0 else: d_relu = 0.0 - d_logits_qk[i, j] = (index_score_shared[i] - - attn_score_shared[i]) * d_relu * weights_shared[j] + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] # dq T.copy(d_logits_qk, d_logits_qk_cast1) @@ -157,7 +153,7 @@ def tl_indexer_bwd_impl( for i, j in T.Parallel(block_I, dim): pos = indices_shared[i] - if ((pos > -1) & (pos <= i_t)): + if (pos > -1) & (pos <= i_t): T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) for i, j in T.Parallel(heads, dim): @@ -184,40 +180,35 @@ def indexer_bwd_interface( dweights = torch.zeros_like(weights) dk = torch.zeros_like(k) kernel = tl_indexer_bwd_impl(heads, dim, topk) - kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, - token_indices) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) return dq, dweights, dk -def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, - TopkIndices: torch.Tensor, AttnScore: torch.Tensor, - offsets: torch.Tensor) -> torch.Tensor: +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: Q.requires_grad_(True) Weights.requires_grad_(True) K.requires_grad_(True) - softmax_scale = Q.shape[-1]**-0.5 + softmax_scale = Q.shape[-1] ** -0.5 all_loss = [] all_log_topk_prob = [] for i in range(offsets.shape[0] - 1): assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] - q = Q[offsets[i]:offsets[i + 1]] - weights = Weights[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] - attn_score = AttnScore[offsets[i]:offsets[i + 1]] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] s = q.shape[0] mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale logits = F.relu(logits) score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) - score = torch.where(mask, score, float('-inf')) + score = torch.where(mask, score, float("-inf")) topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) - loss = F.kl_div( - log_topk_prob.clip(-100, 0), - attn_score.log().clip(-100, 0), - log_target=True, - reduction="sum") + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum") all_loss.append(loss) all_log_topk_prob.append(log_topk_prob) loss = torch.stack(all_loss).sum() @@ -244,15 +235,13 @@ def test_kernel( seq_len = (offsets[i + 1] - offsets[i]).item() mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) logits = torch.ones(seq_len, topk).cuda() - logits = torch.where(mask, logits, float('-inf')) + logits = torch.where(mask, logits, float("-inf")) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) all_attn_score.append(attn_score) attn_score = torch.cat(all_attn_score, dim=0) - topk_indices = repeat( - torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() - index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, - offsets) + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets) dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) @@ -261,5 +250,5 @@ def test_kernel( print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py index b7fa662..8e2f82b 100644 --- a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -53,8 +53,8 @@ def tl_indexer_topk_reducesum_impl( @T.macro def bitonic_sort( - topk_index_shared: T.SharedBuffer([N], dtype=INT32), - topk_value_shared: T.SharedBuffer([N], dtype=FP32), + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), ): T.sync_threads() for i1 in T.serial(num_iters): @@ -62,9 +62,10 @@ def tl_indexer_topk_reducesum_impl( for i in T.Parallel(N): ascending = (i & (1 << (i1 + 1))) != 0 j = i ^ (1 << (i1 - i2)) - if i < j and \ - ((ascending and topk_value_shared[i] > topk_value_shared[j]) or ( - not ascending and topk_value_shared[i] < topk_value_shared[j])): + if i < j and ( + (ascending and topk_value_shared[i] > topk_value_shared[j]) + or (not ascending and topk_value_shared[i] < topk_value_shared[j]) + ): val = topk_value_shared[i] topk_value_shared[i] = topk_value_shared[j] topk_value_shared[j] = val @@ -75,13 +76,13 @@ def tl_indexer_topk_reducesum_impl( @T.prim_func def tl_indexer_topk_reducesum_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), - Weights: T.Tensor(weights_shape, dtype), - IndexK: T.Tensor(index_k_shape, dtype), - TopkIndices: T.Tensor(topk_indices_shape, INT32), - ReduceSum: T.Tensor(topk_indices_shape, FP32), - Offsets: T.Tensor(offsets_shape, INT32), - TokenIndices: T.Tensor(token_indices_shape, INT32), + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), ): with T.Kernel(seq_len, threads=num_threads) as (bx): i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] @@ -92,7 +93,7 @@ def tl_indexer_topk_reducesum_impl( topk_value_shared = T.alloc_shared([N], dtype=FP32) T.fill(topk_index_shared, -1) - T.fill(topk_value_shared, float('-inf')) + T.fill(topk_value_shared, float("-inf")) T.sync_threads() index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) @@ -113,8 +114,7 @@ def tl_indexer_topk_reducesum_impl( index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) for i, j in T.Parallel(block_K, dim): - index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, - j], 0) + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) T.sync_threads() logits = T.alloc_fragment((block_K, heads), FP32) @@ -144,7 +144,7 @@ def tl_indexer_topk_reducesum_impl( T.sync_threads() for i in T.Parallel(block_K): if k_st + i > i_t: - logits_sum[i] = float('-inf') + logits_sum[i] = float("-inf") j = offset + i topk_index_shared[j] = k_st + i topk_value_shared[j] = logits_sum[i] @@ -209,22 +209,21 @@ def indexer_topk_reducesum_interface( return topk_indices, topk_score -def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, - offsets: torch.Tensor) -> torch.Tensor: +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: all_topk_indices = [] all_topk_score = [] for i in range(offsets.shape[0] - 1): assert (offsets[i + 1] - offsets[i]).item() >= topk - q = Q[offsets[i]:offsets[i + 1]] - weights = Weights[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - softmax_scale = q.shape[-1]**-0.5 + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 s = q.shape[0] mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") logits = F.relu(logits) logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale - logits = torch.where(mask, logits, float('-inf')) + logits = torch.where(mask, logits, float("-inf")) topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) all_topk_indices.append(topk_indices) @@ -265,13 +264,10 @@ def test_kernel( set_trt = set(trt_np[mask]) intersection = set_ref & set_trt - print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) - print( - f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}" - ) + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 33c21cb..0b08551 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -19,15 +19,15 @@ def preprocess( assert dtype == "bfloat16" assert accum_dtype == "float" - S = T.symbolic('S') + S = T.symbolic("S") shape = [S, H, D] @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: T.Tensor(shape, dtype), - Delta: T.Tensor([S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -36,13 +36,12 @@ def preprocess( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) - T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) return preprocess_kernel @@ -59,19 +58,19 @@ def postprocess( ): assert dtype == "bfloat16" assert accum_dtype == "float" - S_kv = T.symbolic('S_kv') + S_kv = T.symbolic("S_kv") dkv_shape = [S_kv, kv_group, D + D_tail] @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): T.copy( - dKV[bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bx * block_N:(bx + 1) * block_N, by, :], + dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -82,7 +81,8 @@ def postprocess( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) def bwd( H, D, @@ -98,17 +98,17 @@ def bwd( dtype="bfloat16", accum_dtype="float", ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" assert dtype == "bfloat16" assert accum_dtype == "float" assert indices_dtype == "int32" if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) - B_plus_one = T.symbolic('B_plus_one') - S = T.symbolic('S') + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") H_kv = H // kv_group q_shape = [S, H, D + D_tail] @@ -132,16 +132,16 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - Offsets: T.Tensor(offsets_shape, indices_dtype), - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): Q_shared = T.alloc_shared([padded_H, D], dtype) @@ -163,32 +163,32 @@ def bwd( acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) - acc_dkv_tail_shared = T.view( - KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] bos, eos = Offsets[b_i], Offsets[b_i + 1] max_kv_i = s_i - T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): # Check which indices are valid for bi_i in T.Parallel(BS): - mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) # Compute attention scores for h_i, bi_i in T.Parallel(padded_H, BS): @@ -196,65 +196,33 @@ def bwd( # Load KV, V for this block of indices for bi_i, d_i in T.Parallel(BS, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], - bz, D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - - Lse[bos + s_i, bz * padded_H + h_i]) + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -263,44 +231,32 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * - (BS // split_store)], bz, d_i * 4], - acc_dkv_shared[bi_i, d_i * 4]) + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * - (BS // split_store)], bz, D + d_i * 4], - acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - offsets, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -333,16 +289,9 @@ def sparse_mla_bwd(q, return dq, dkv -def ref_sparse_mla_bwd_interface(q, - kv, - o, - do, - indices, - lse, - offsets, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -352,32 +301,25 @@ def ref_sparse_mla_bwd_interface(q, return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=2048, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=512, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") - indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda') + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") for i in range(offsets.shape[0] - 1): seq_len = (offsets[i + 1] - offsets[i]).item() assert seq_len >= topk for t in range(seq_len): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[offsets[i] + t, h, :len(i_i)] = i_i + indices[offsets[i] + t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) @@ -388,13 +330,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -402,19 +346,9 @@ def test_sparse_mla_bwd(B=1, ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=2048, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=512, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py index 5f03dfb..6ec3caa 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_fwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -27,15 +27,12 @@ def sparse_mla_fwd( num_stages=2, threads=128, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 else: sm_scale = sm_scale @@ -58,9 +55,9 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -76,19 +73,18 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, kv_group, threads=threads) as ( - bx, - by, - ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -122,17 +118,13 @@ def sparse_mla_fwd( T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): - mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], - g_i, D + d_i] + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -177,16 +169,9 @@ def sparse_mla_fwd( return main -def sparse_mla_fwd_interface(q, - kv, - indices, - offsets, - sm_scale=None, - return_p_sum: bool = False, - d_v=512, - block_I=32, - num_stages=2, - threads=128): +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -205,16 +190,8 @@ def sparse_mla_fwd_interface(q, token_indices = prepare_token_indices(offsets) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices, offsets, token_indices) return out, lse @@ -224,9 +201,9 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu KV = KV.float() all_o = [] for i in range(offsets.shape[0] - 1): - q = Q[None, offsets[i]:offsets[i + 1]] - kv = KV[None, offsets[i]:offsets[i + 1]] - indices = Indices[None, offsets[i]:offsets[i + 1]].clone() + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() indices = indices.transpose(1, 2) b, sq, h, dim_q = q.shape @@ -240,15 +217,15 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) indices[indices > sk] = sk mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -265,18 +242,20 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -289,10 +268,9 @@ def test_sparse_mla_fwd(B=1, for t in range(seq_len): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[offsets[i] + t, h, :len(i_i)] = i_i + indices[offsets[i] + t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -301,8 +279,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -329,4 +306,5 @@ if __name__ == "__main__": check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py index 94bdb8f..6675215 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -30,14 +30,11 @@ def tl_sparse_mla_topk_reducesum_impl( num_stages=2, threads=128, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 batch_plus_one = T.symbolic("batch_plus_one") seq_len = T.symbolic("seq_len") @@ -52,9 +49,9 @@ def tl_sparse_mla_topk_reducesum_impl( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -78,19 +75,18 @@ def tl_sparse_mla_topk_reducesum_impl( @T.prim_func def tl_sparse_mla_topk_reducesum_kernel( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore - Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore - ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, kv_group, threads=threads) as ( - bx, - by, - ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -119,17 +115,13 @@ def tl_sparse_mla_topk_reducesum_impl( T.copy(Lse[bos + s_i, H0:H1], lse) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): - mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], - g_i, D + d_i] + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -150,7 +142,7 @@ def tl_sparse_mla_topk_reducesum_impl( for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) T.reduce_sum(acc_s, reducesum, dim=0) - T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI]) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI]) return tl_sparse_mla_topk_reducesum_kernel @@ -178,29 +170,26 @@ def sparse_mla_topk_reducesum_interface( return attn_score -def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, - offsets: torch.Tensor): +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): # q: [batch, seq_len, heads, dim] # k: [batch, seq_len, dim] - sm_scale = Q.shape[-1]**-0.5 + sm_scale = Q.shape[-1] ** -0.5 all_lse = [] all_topk_score = [] for i in range(offsets.shape[0] - 1): - q = Q[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] seq_len = q.shape[0] - mask = (torch.arange(seq_len)[:, None] - >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() - logits = einsum(q, k, 's1 h d, s2 d -> s1 h s2') * sm_scale - logits = torch.where(mask, logits, float('-inf')) + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) score = F.softmax(logits, dim=-1, dtype=torch.float32) score_sum = score.sum(dim=-2) topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) max_logits = logits.amax(dim=-1).to(torch.float32) - lse = torch.log( - (logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits all_lse.append(lse) all_topk_score.append(topk_score) lse = torch.cat(all_lse, dim=0) @@ -222,20 +211,16 @@ def test_kernel( kv = torch.randn((S, D + tail_D)).cuda().bfloat16() offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() - topk_indices = repeat( - torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) kv = kv.unsqueeze(-2) topk_indices = topk_indices.unsqueeze(-2) - attn_score = sparse_mla_topk_reducesum_interface( - q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) - print( - f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}" - ) + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py b/examples/dsa_sparse_finetune/utils.py index 691af64..96afd06 100644 --- a/examples/dsa_sparse_finetune/utils.py +++ b/examples/dsa_sparse_finetune/utils.py @@ -66,10 +66,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index be018c8..97ce7d9 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -29,9 +29,9 @@ def matmul_dynamic_mnk( @T.prim_func def dynamic_matmul( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -53,15 +53,14 @@ def matmul_dynamic_mnk( return dynamic_matmul -def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads): +def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads): print( f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" ) - kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) import torch + if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -103,8 +102,7 @@ def main(M=16384, N=16384, K=16384): accum_dtype = "float32" num_stages = 3 threads = 128 - matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) if __name__ == "__main__": diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bc9bb4d..464312c 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -12,10 +12,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -24,7 +22,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) @@ -41,19 +39,21 @@ def get_configs(M, N): def get_best_config(M, N): - def kernel(block_M=None, block_N=None, threads=None): return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N)) + .set_compile_args( out_idx=[-1], target="cuda", - ).set_profile_args( + ) + .set_profile_args( supply_type=tilelang.TensorSupplyType.Auto, ref_prog=ref_program, skip_check=False, ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/examples/flash_attention/bert_padding.py b/examples/flash_attention/bert_padding.py index 7058fd7..15c4097 100644 --- a/examples/flash_attention/bert_padding.py +++ b/examples/flash_attention/bert_padding.py @@ -6,7 +6,6 @@ from einops import rearrange, repeat class IndexFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -15,9 +14,7 @@ class IndexFirstAxis(torch.autograd.Function): second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, - repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): @@ -40,14 +37,12 @@ index_first_axis = IndexFirstAxis.apply class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. output[indices] = values # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) @@ -66,7 +61,6 @@ index_put_first_axis = IndexPutFirstAxis.apply class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). - + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: ``` [ @@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ length = attention_mask_in_length.sum(dim=-1) seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange( - seqlen, device=length.device, dtype=length.dtype).expand(len(length), - seqlen) < length.unsqueeze(1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 968d1de..d1f5843 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -6,11 +6,13 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -107,26 +107,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): dtype = "float16" accum_dtype = "float" @@ -135,35 +136,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -173,15 +166,15 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -201,35 +194,36 @@ def flashattn_bwd_atomic_add(batch, dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -241,29 +235,21 @@ def flashattn_bwd_atomic_add(batch, for i, j in T.Parallel(block_N, dim_qk): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -275,15 +261,15 @@ def flashattn_bwd_split(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -303,37 +289,38 @@ def flashattn_bwd_split(batch, dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -346,16 +333,15 @@ def flashattn_bwd_split(batch, T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -373,7 +359,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -390,17 +379,8 @@ class _attention(torch.autograd.Function): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -413,17 +393,8 @@ class _attention(torch.autograd.Function): dv = dv.to(torch.float16) else: kernel = flashattn_bwd_split( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -445,53 +416,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -508,7 +471,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -528,17 +491,15 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -550,5 +511,4 @@ if __name__ == "__main__": # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index c427908..c6cf336 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -9,11 +9,13 @@ tilelang.disable_cache() @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -23,11 +25,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -43,27 +45,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead T.fill(scores_max, T.Cast(accum_dtype, -1e30)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -81,18 +79,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -101,9 +101,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -112,12 +112,12 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @@ -128,9 +128,11 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): dtype = "float16" accum_dtype = "float" @@ -141,46 +143,37 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, - by, :]) - T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -190,15 +183,15 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -219,37 +212,38 @@ def flashattn_bwd_atomic_add(batch, dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -259,33 +253,23 @@ def flashattn_bwd_atomic_add(batch, T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) T.copy(dv, dv_shared) - T.atomic_add( - dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split_novarlen(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -297,15 +281,15 @@ def flashattn_bwd_split_novarlen(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -325,37 +309,38 @@ def flashattn_bwd_split_novarlen(batch, dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -368,16 +353,15 @@ def flashattn_bwd_split_novarlen(batch, T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -395,7 +379,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -412,17 +399,8 @@ class _attention(torch.autograd.Function): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -433,17 +411,8 @@ class _attention(torch.autograd.Function): dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split_novarlen( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -451,8 +420,7 @@ class _attention(torch.autograd.Function): dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) return dq, dk, dv, None, None, None @@ -466,53 +434,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -529,7 +489,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -552,17 +512,15 @@ if __name__ == "__main__": print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -574,5 +532,4 @@ if __name__ == "__main__": # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index a9604f4..112438f 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -15,32 +15,21 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask @tilelang.jit( - out_idx=[5, 6], pass_configs={ + out_idx=[5, 6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_fwd(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -51,13 +40,13 @@ def flashattn_fwd(batch, @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -102,15 +91,17 @@ def flashattn_fwd(batch, if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen), 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.Cast(accum_dtype, -1e30), + ) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)) + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, d in T.Parallel(block_N, dim_v): if k * block_N + i < k_current_seqlen: @@ -148,9 +139,11 @@ def flashattn_fwd(batch, @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -159,10 +152,10 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -201,9 +194,11 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): dtype = "float16" accum_dtype = "float" @@ -214,46 +209,39 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :]) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :]) - T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -264,20 +252,19 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -303,58 +290,54 @@ def flashattn_bwd_atomic_add(batch, q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) # dsT: (block_kv, block_q) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) @@ -364,49 +347,40 @@ def flashattn_bwd_atomic_add(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) T.atomic_add( - dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, - bx, :], + dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], dq_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dv, dv_shared) T.atomic_add( - dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dv_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dk, dk_shared) T.atomic_add( - dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dk_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -419,20 +393,19 @@ def flashattn_bwd_split(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -457,59 +430,55 @@ def flashattn_bwd_split(batch, q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): # Note: The padding zero of varlen should be considered in T.copy - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -520,62 +489,37 @@ def flashattn_bwd_split(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim_qk): if k_base * block_N + i < q_current_seqlen: - T.atomic_add( - dQ[q_start_idx + k_base * block_N + i, bx, j], - dq[i, j], - memory_order="relaxed") + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") T.copy(dv, dv_shared) - T.copy( - dv_shared, - dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy( - dk_shared, - dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, - q, - k, - v, - seqlens_q, - seqlens_k, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal, - groups=1, - use_atomic=True): + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] block_M = 128 block_N = 64 - q_unpad, indices_q, _, _ = unpad_input( - q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) - k_unpad, indices_k, _, _ = unpad_input( - k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) - v_unpad, _, _, _ = unpad_input( - v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, - causal, block_M, block_N, groups) + mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) - ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, - cu_seqlens_q, cu_seqlens_k) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH ctx.causal = causal ctx.use_atomic = use_atomic @@ -590,8 +534,7 @@ class _attention(torch.autograd.Function): N_CTX = do.shape[1] q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors # lse_clone = lse.clone() - do_unpad, _, _, _ = unpad_input( - do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) total_q, H, D_HEAD_QK = q.shape total_kv, HEAD_KV, D_HEAD_V = v.shape groups = H // HEAD_KV @@ -624,7 +567,8 @@ class _attention(torch.autograd.Function): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32) @@ -645,13 +589,13 @@ class _attention(torch.autograd.Function): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) @@ -670,15 +614,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): # HQ = HKV * groups # To handle precision issue Q, K, V = Q.float(), K.float(), V.float() - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if padding_mask is not None: scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) @@ -686,41 +628,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) if padding_mask is not None: output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) @@ -729,8 +665,7 @@ def main(BATCH: int = 1, # In training backward pass, seqlens_k should be the same as seqlens_q seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q - O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, causal, groups, use_atomic) + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -772,17 +707,15 @@ if __name__ == "__main__": print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Can be set to True/False for testing args.causal = True @@ -796,5 +729,4 @@ if __name__ == "__main__": # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index e916812..adb7e06 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -6,11 +6,13 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -107,32 +107,24 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -142,15 +134,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -171,45 +163,39 @@ def flashattn_bwd(batch, dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -221,18 +207,17 @@ def flashattn_bwd(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -250,7 +235,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -264,18 +252,7 @@ class _attention(torch.autograd.Function): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) delta = mod_prep(o, do) - kernel = flashattn_bwd( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -298,52 +275,36 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False): +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -360,7 +321,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -380,13 +341,13 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index a6d3b5f..408d6e5 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -9,7 +9,6 @@ from functools import partial class FlashAttentionTuneSpace: - def __init__( self, block_sizes=(64, 128, 256), @@ -40,7 +39,7 @@ def get_configs(user_config=None): warp_M = block_M // warp_count warp_N = block_N // warp_count - if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0): + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: continue shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) @@ -48,31 +47,26 @@ def get_configs(user_config=None): continue for num_stages in config.num_stages_range: - valid_configs.append({ - "block_M": block_M, - "block_N": block_N, - "num_stages": num_stages, - "threads": threads, - }) + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) return valid_configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - groups=1, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -90,15 +84,13 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -111,18 +103,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -148,18 +140,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -175,25 +167,24 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -203,50 +194,34 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(batch: int = 1, - heads: int = 64, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 16, - tune: bool = False): +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=64, - block_N=64, - num_stages=2, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -270,12 +245,12 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 03ad15e..3492be7 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -24,9 +24,11 @@ def get_configs(): rep=10, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,7 +41,7 @@ def flashattn( num_stages=0, threads=128, ): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -57,15 +59,13 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -78,18 +78,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -115,18 +115,18 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -142,30 +142,30 @@ def flashattn( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -175,23 +175,21 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -209,18 +207,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -244,12 +232,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index ccc50e4..87b11f7 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -10,14 +10,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), - upcast=True, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), + upcast=True, ): if causal: window_size = (window_size[0], 0) @@ -26,7 +26,7 @@ def attention_ref( q, k, v = q.float(), k.float(), v.float() b, T, Hq, D = q.shape S = k.shape[1] - scale = (1.0 / D)**0.5 + scale = (1.0 / D) ** 0.5 k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) scores = torch.einsum("bthd,bshd->bhts", q, k) @@ -54,21 +54,13 @@ def attention_ref( @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] @@ -78,17 +70,15 @@ def flashattn(batch_size, @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -102,10 +92,12 @@ def flashattn(batch_size, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) batch_idx = bz head_idx = by @@ -119,36 +111,34 @@ def flashattn(batch_size, q_current_seqlen = q_end_idx - q_start_idx kv_current_seqlen = k_end_idx - kv_start_idx - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], - Q_shared) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(q_current_seqlen + - (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) - if is_causal else T.ceildiv(kv_current_seqlen, block_N)) + T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], K_shared) + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, - j] = T.if_then_else((bx * block_M + i < k * block_N + j) or - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= kv_current_seqlen), -1e9, 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= kv_current_seqlen), -1e9, - 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -170,9 +160,7 @@ def flashattn(batch_size, for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @@ -187,13 +175,9 @@ def flashattn(batch_size, return main -def main(batch: int = 1, - heads: int = 64, - q_seqlen: int = 2048, - k_seqlen: int = 2048, - dim: int = 128, - groups: int = 16, - is_causal: bool = False): +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): assert heads % groups == 0, "heads must be divisible by groups" flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim @@ -231,24 +215,12 @@ def main(batch: int = 1, output_pad_fn, _, _, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] UKV = k_unpad.shape[0] - kernel = flashattn( - batch, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) @@ -263,23 +235,19 @@ def main(batch: int = 1, ) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) print("All checks passed.✅") - latency = do_bench( - lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), - _n_warmup=5, - _n_repeat=5) + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='query heads') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') - parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', help='causal attention') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") args = parser.parse_args() - main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, - args.is_causal) + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index d91d177..81eb6d1 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -39,28 +41,24 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # T.copy(Q_shared, Q_local) # for i, j in T.Parallel(block_M, dim): # Q_local[i, j] *= scale - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -78,18 +76,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -98,9 +98,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -109,26 +109,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -137,40 +138,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -194,38 +197,39 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -238,14 +242,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, H, N_CTX, D_HEAD = q.shape @@ -287,15 +290,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(2) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -310,9 +313,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -353,10 +354,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 7c85f98..427a0f6 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,25 +40,21 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -74,18 +72,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -94,9 +94,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -105,26 +105,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -133,40 +134,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -190,35 +193,36 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - }) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -231,14 +235,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -280,15 +283,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -303,9 +306,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -344,10 +345,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index e8ee5d9..813f379 100644 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,26 +40,22 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -75,18 +73,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -95,9 +95,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -106,37 +106,39 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -161,49 +163,43 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dk_shared = T.alloc_shared([block_M, dim], dtype) dq_shared = T.alloc_shared([block_N, dim], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. T.wait_wgmma(0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -214,17 +210,16 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -266,15 +261,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -289,9 +284,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -311,7 +304,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -329,10 +322,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index e0e0bca..7fa5549 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -15,20 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -48,7 +41,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -70,18 +63,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -110,18 +103,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -137,43 +130,42 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -191,18 +183,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -227,12 +209,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=1, help='heads') - parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal', default=False) - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index b797bbc..440a2cd 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -15,20 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -48,7 +41,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -70,18 +63,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -108,18 +101,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -135,48 +128,48 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -194,18 +187,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -230,12 +213,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index b5b7282..888914c 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -15,19 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @@ -43,16 +37,14 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: # We shall fill -inf for OOB positions for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -65,18 +57,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,18 +94,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -129,40 +121,39 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -179,17 +170,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -213,11 +195,11 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 02d8bae..b54d3e6 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -15,19 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @@ -43,16 +37,14 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: # We shall fill -inf for OOB positions for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -65,18 +57,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,18 +94,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -129,45 +121,45 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -184,17 +176,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -218,11 +201,11 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index bbb4546..f7bb36f 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -11,14 +11,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - upcast=True, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, ): """ Arguments: @@ -47,7 +47,7 @@ def attention_ref( if upcast: q, k, v = q.float(), k.float(), v.float() dim = q.shape[-1] - scale = (1.0 / dim)**0.5 # log2(e) + scale = (1.0 / dim) ** 0.5 # log2(e) k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) scores = torch.einsum("bthd,bshd->bhts", q, k) @@ -68,20 +68,13 @@ def attention_ref( @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=32): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [UQ, heads, dim] k_shape = [UKV, heads, dim] v_shape = [UKV, heads, dim] @@ -92,17 +85,15 @@ def flashattn(batch_size, @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(k_shape, dtype), - V_unpad: T.Tensor(v_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") K_shared = T.alloc_shared([block_N, dim], dtype, "shared") V_shared = T.alloc_shared([block_N, dim], dtype, "shared") @@ -151,15 +142,17 @@ def flashattn(batch_size, K_shared[i, d] = 0 if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] # unpadded query length UK = k_unpad.shape[0] # unpadded key length @@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim) diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index b184fc6..da172bb 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined(): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_gqa_fwd_bshd_wgmma_pipelined(): - example_gqa_fwd_bshd_wgmma_pipelined.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda def test_example_gqa_fwd_bshd(): - example_gqa_fwd_bshd.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda diff --git a/examples/flash_attention/varlen_utils.py b/examples/flash_attention/varlen_utils.py index 4301215..43e21cc 100644 --- a/examples/flash_attention/varlen_utils.py +++ b/examples/flash_attention/varlen_utils.py @@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask -def generate_qkv(q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False): +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: q: (batch_size, seqlen_q, nheads, d) @@ -39,15 +31,12 @@ def generate_qkv(q, if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q - ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( - output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) @@ -55,8 +44,7 @@ def generate_qkv(q, else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) max_seqlen_k = seqlen_k if qkvpacked: @@ -67,8 +55,7 @@ def generate_qkv(q, if query_padding_mask is not None: dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( qkv_unpad.detach().requires_grad_(), cu_seqlens_q, @@ -84,8 +71,7 @@ def generate_qkv(q, if key_padding_mask is not None: dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: - dkv_pad_fn = lambda dkv_unpad: rearrange( - dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 7ccd983..136a512 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -20,13 +20,7 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]: # TODO(lei): fix warp specialized and tma lower pass def get_pass_configs(): - return { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - } + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) -def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, - threads): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim] @@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, hid = by cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) - T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -127,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - K[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], K_shared) + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) T.copy( - mask[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head], mask_local) + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), - acc_s[i, j], -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] T.copy( - V[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], V_shared) + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] @@ -216,14 +217,13 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, if i < valid_block_H: glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, - sid, :]) + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim], dtype) @@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, lse_max_local = T.alloc_fragment([128], accum_dtype) scale_local = T.alloc_fragment([128], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), - # lse_local: (local_id, thread_id) - lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), + # lse_local: (local_id, thread_id) + lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -263,26 +265,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, @T.prim_func def flashattn_gqa_decode_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn_split(Q, K, V, mask, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn(Q, K, V, mask, Output) @@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): dim = query.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] if mask is not None: - mask = rearrange(mask, 'b s h -> b h s') + mask = rearrange(mask, "b s h -> b h s") mask = mask.unsqueeze(1) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask): seqlen_kv = K.size(1) num_head_groups = nheads // groups - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), - device="cuda", - dtype=torch.float16) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, num_head_groups, groups), - device="cuda", - dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) @@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask): glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) Q_ = Q * scale - Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bghd,bkhd->bghk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] if mask is not None: - mask_local = mask[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] - mask_local = rearrange(mask_local, 'b s h -> b h s') + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") mask_local = mask_local.unsqueeze(1) - acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bghk,bkhd->bghd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum - acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') - logsum_out = rearrange(logsum, 'b g h->b (h g)') + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") acc_o_out /= logsum_out[:, :, None] - logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") gacc_o[ks, :, :, :] = acc_o_out glogsum[ks, :, :] = logsum_out @@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): sim = calc_sim(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if assert_: - raise AssertionError(f'{name} Error: {diff}') + raise AssertionError(f"{name} Error: {diff}") else: if print_: - print(f'passed: {name} diff={diff}') + print(f"passed: {name} diff={diff}") -def main(batch: int = 1, - heads: int = 32, - groups: int = 8, - kv_seqlen: int = 8192, - dim: int = 128, - tune: bool = False): +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim qk_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim total_flops = qk_flops + pv_flops - if (not tune): + if not tune: config, sm_version = get_heuristic_config() kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) @@ -497,11 +485,11 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 16924eb..0fdd529 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -74,14 +73,9 @@ def _fwd_inner( return m_i, l_i, acc - @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [4, 8]\ - for num_stages in [2, 4]\ - ], - key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], + key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], ) @triton.jit def _fwd_kernel_varlen( @@ -107,13 +101,12 @@ def _fwd_kernel_varlen( stride_od, stride_sb, stride_sh, - stride_sn, #bmask shape [b, q_h, seq/BLOCK_N] + stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] gqa_group_size: tl.constexpr, BLOCK_H: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, ): - off_z = tl.program_id(0) off_h_for_kv = tl.program_id(1) off_h_q = off_h_for_kv * gqa_group_size @@ -134,8 +127,7 @@ def _fwd_kernel_varlen( S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh mask_h = offs_h < gqa_group_size - q = tl.load( - Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) + q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) if s_aux is not None: sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) @@ -189,14 +181,12 @@ def _fwd_kernel_varlen( acc = acc.to(O.dtype.element_ty) - tl.store( - O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, - acc, - mask=mask_h[:, None]) + tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) def get_configs(): import itertools + block_N = [64, 128] block_H = [64] num_split = [1] @@ -204,31 +194,16 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") -def flashattn(batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim] @@ -243,13 +218,13 @@ def flashattn(batch, @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor([batch, heads, dim], dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -268,13 +243,15 @@ def flashattn(batch, # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) s_aux_shared = T.alloc_shared([block_H], "float32") - T.annotate_layout({ - # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - # K_shared: tilelang.layout.make_swizzled_layout(K_shared), - # V_shared: tilelang.layout.make_swizzled_layout(V_shared), - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - # S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + # K_shared: tilelang.layout.make_swizzled_layout(K_shared), + # V_shared: tilelang.layout.make_swizzled_layout(V_shared), + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + # S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) bid = bx hid = by @@ -284,7 +261,7 @@ def flashattn(batch, cur_end_k = cu_seqlens_k[bid + 1] cur_seqlen_k = cur_end_k - cur_start_k - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -292,15 +269,13 @@ def flashattn(batch, # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], # -T.infinity(accum_dtype)) - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -320,12 +295,11 @@ def flashattn(batch, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_sink: - T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) for i in T.Parallel(block_H): logsum[i] += s_aux_shared[i] for i, j in T.Parallel(block_H, dim): @@ -338,20 +312,19 @@ def flashattn(batch, for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) # T.copy(S_fragment, S_shared) - T.copy(S_shared[:valid_block_H, :], S[bid, - hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), ): flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) @@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang( gqa_group_size = q_h // k_h O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), - dtype=Q.dtype, - device=Q.device) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) if use_per_kv_head_sparse_index: @@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode( BLOCK_H = 64 O = torch.zeros_like(Q) - S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), - dtype=Q.dtype, - device=Q.device) + S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) def grid(META): return (batch, k_h) @@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args): dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Convert to varlen format for K, V @@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args): v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) max_seqlen_k = k_seqlen print(f"q shape: {q.shape}") @@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args): num_tokens, q_h, head_size = q.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Test our decode kernel O_triton, S_triton = flash_attn_with_attn_pool_decode( @@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q, k_varlen, @@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args): tl_kernel=tl_kernel, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Compute torch reference q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args): if sink is None: # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] attn_weights = torch.softmax(logits, dim=-1) O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args): unnormalized_scores = torch.exp(logits - logits_or_sinks_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat).squeeze(2) # [batch, q_heads, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] # Compute attention score pooling attn_score_pooled = torch.max_pool2d( attn_weights.squeeze(2), # [b, q_heads, k_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(torch.float16) + ceil_mode=True, + ).to(torch.float16) print("S_tilelang", S_tilelang) print("attn_score_pooled", attn_score_pooled) @@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args): print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose( - S_tilelang, attn_score_pooled, atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" print("✅ All tests passed!") @@ -616,7 +575,7 @@ def test_varlen_decode_main(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Generate variable length k sequences @@ -624,7 +583,7 @@ def test_varlen_decode_main(args): print(f"k_seqlens: {k_seqlens}") # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -634,9 +593,9 @@ def test_varlen_decode_main(args): print(f"cu_seqlens_k: {cu_seqlens_k}") # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -649,8 +608,7 @@ def test_varlen_decode_main(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Test our decode kernel O_triton, S_triton = flash_attn_with_attn_pool_decode( @@ -663,7 +621,8 @@ def test_varlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q_decode, k_varlen, @@ -678,9 +637,7 @@ def test_varlen_decode_main(args): tl_kernel=tl_kernel, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Create torch reference - pad tensors for comparison k_padded_list = [] @@ -694,8 +651,8 @@ def test_varlen_decode_main(args): k_end = cu_seqlens_k[i + 1] # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) k_padded[:actual_k_len] = k_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end] @@ -704,10 +661,8 @@ def test_varlen_decode_main(args): v_padded_list.append(v_padded) # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack( - k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack( - v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -717,20 +672,17 @@ def test_varlen_decode_main(args): print(f"v_padded_batched shape: {v_padded_batched.shape}") # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] if sink is None: # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float('-inf') + attn_score[i, :, :, actual_k_len:] = float("-inf") attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] @@ -743,13 +695,12 @@ def test_varlen_decode_main(args): O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float('-inf') + logits[i, :, :, actual_k_len:] = float("-inf") sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -765,8 +716,7 @@ def test_varlen_decode_main(args): attn_weights[i, :, :, actual_k_len:] = 0.0 # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat) # [b, q_heads, 1, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] @@ -775,7 +725,8 @@ def test_varlen_decode_main(args): attn_weights.squeeze(2), # [b, q_heads, max_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] print(f"O_triton shape: {O_triton.shape}") print(f"O_tilelang shape: {O_tilelang.shape}") @@ -791,22 +742,16 @@ def test_varlen_decode_main(args): print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], - attn_score_pooled, - atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) print("✅ All tests passed!") @@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args): k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args): cu_seqlens_k[batch_size] = total_k_tokens # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(" Using sink attention with sink values") print("Setup complete:") @@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Benchmark print("⚡ Benchmarking Tilelang kernel (100 iterations)...") @@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args): # Benchmark print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, - cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, - block_size) + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Speedup: {(triton_time / tilelang_time):.3f}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size') - parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') - parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') - parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') - parser.add_argument( - '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') - parser.add_argument('--block_size', type=int, default=64, help='Block size for computation') - parser.add_argument( - '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') - parser.add_argument( - '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') - parser.add_argument( - '--test_sink', action='store_true', help='Test with sink attention mechanism') - parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') - parser.add_argument( - '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") args = parser.parse_args() args.test_sink = True args.test_varlen = False - args.dtype = 'float16' + args.dtype = "float16" args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py index e565cbe..3537e5a 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -10,6 +10,7 @@ torch.manual_seed(0) def get_configs(): import itertools + block_N = [64, 128] block_H = [64] num_split = [1] @@ -17,32 +18,28 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs # @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") -def flashattn(batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - page_block_size, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn( + batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, + block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim] @@ -51,21 +48,23 @@ def flashattn(batch, dtype = "float16" accum_dtype = "float" kv_group_num = heads // k_heads - assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N" + assert page_block_size >= block_N and page_block_size % block_N == 0, ( + "page_block_size must be larger than block_N and a multiple of block_N" + ) valid_block_H = min(block_H, kv_group_num) # TODO: check if max_seqlen_kv is correct for varlen case @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), - Output: T.Tensor([batch, heads, dim], dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -91,7 +90,7 @@ def flashattn(batch, cur_end_k = cu_seqlens_k[bid + 1] cur_seqlen_k = cur_end_k - cur_start_k - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -99,15 +98,12 @@ def flashattn(batch, # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( - k * block_N) % page_block_size - T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :], - K_shared) + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -127,14 +123,12 @@ def flashattn(batch, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( - k * block_N) % page_block_size - T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :], - V_shared) + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_sink: - T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) for i in T.Parallel(block_H): logsum[i] += s_aux_shared[i] for i, j in T.Parallel(block_H, dim): @@ -144,20 +138,19 @@ def flashattn(batch, for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - T.copy(S_shared[:valid_block_H, :], S[bid, - hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), ): flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) @@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang( gqa_group_size = q_h // k_h O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), - dtype=Q.dtype, - device=Q.device) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) if use_per_kv_head_sparse_index: @@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args): dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Convert to varlen format for K, V @@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args): v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) max_seqlen_k = k_seqlen print(f"q shape: {q.shape}") @@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args): num_tokens, q_h, head_size = q.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q, k_varlen, @@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args): block_table=block_table, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Compute torch reference q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args): if sink is None: # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] attn_weights = torch.softmax(logits, dim=-1) O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args): unnormalized_scores = torch.exp(logits - logits_or_sinks_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat).squeeze(2) # [batch, q_heads, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] # Compute attention score pooling attn_score_pooled = torch.max_pool2d( attn_weights.squeeze(2), # [b, q_heads, k_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(torch.float16) + ceil_mode=True, + ).to(torch.float16) print("S_tilelang", S_tilelang) print("attn_score_pooled", attn_score_pooled) @@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args): print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose( - S_tilelang, attn_score_pooled, atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" print("✅ All tests passed!") @@ -368,7 +348,7 @@ def test_varlen_decode_main(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Generate variable length k sequences @@ -376,7 +356,7 @@ def test_varlen_decode_main(args): print(f"k_seqlens: {k_seqlens}") # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -386,9 +366,9 @@ def test_varlen_decode_main(args): print(f"cu_seqlens_k: {cu_seqlens_k}") # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -401,11 +381,9 @@ def test_varlen_decode_main(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -425,7 +403,8 @@ def test_varlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q_decode, k_varlen, @@ -441,9 +420,7 @@ def test_varlen_decode_main(args): block_table=block_table, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Create torch reference - pad tensors for comparison k_padded_list = [] @@ -457,8 +434,8 @@ def test_varlen_decode_main(args): k_end = cu_seqlens_k[i + 1] # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) k_padded[:actual_k_len] = k_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end] @@ -467,10 +444,8 @@ def test_varlen_decode_main(args): v_padded_list.append(v_padded) # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack( - k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack( - v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -480,20 +455,17 @@ def test_varlen_decode_main(args): print(f"v_padded_batched shape: {v_padded_batched.shape}") # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] if sink is None: # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float('-inf') + attn_score[i, :, :, actual_k_len:] = float("-inf") attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] @@ -506,13 +478,12 @@ def test_varlen_decode_main(args): O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float('-inf') + logits[i, :, :, actual_k_len:] = float("-inf") sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -528,8 +499,7 @@ def test_varlen_decode_main(args): attn_weights[i, :, :, actual_k_len:] = 0.0 # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat) # [b, q_heads, 1, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] @@ -538,7 +508,8 @@ def test_varlen_decode_main(args): attn_weights.squeeze(2), # [b, q_heads, max_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] print(f"O_triton shape: {O_triton.shape}") print(f"O_tilelang shape: {O_tilelang.shape}") @@ -554,22 +525,16 @@ def test_varlen_decode_main(args): print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], - attn_score_pooled, - atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) print("✅ All tests passed!") @@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args): k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args): cu_seqlens_k[batch_size] = total_k_tokens # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(" Using sink attention with sink values") print("Setup complete:") @@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args): # Benchmark print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, - cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, - block_size) + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Speedup: {(triton_time / tilelang_time):.3f}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size') - parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') - parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') - parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') - parser.add_argument( - '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') - parser.add_argument('--block_size', type=int, default=128, help='Block size for computation') - parser.add_argument( - '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') - parser.add_argument( - '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') - parser.add_argument( - '--test_sink', action='store_true', help='Test with sink attention mechanism') - parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') - parser.add_argument( - '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') - parser.add_argument('--page_block_size', type=int, default=128, help='Page block size') + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") args = parser.parse_args() args.test_sink = True args.test_varlen = True - args.dtype = 'float16' + args.dtype = "float16" args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index 0360b3e..d0381bc 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -10,7 +10,7 @@ num_split = 4 @tilelang.jit(out_idx=[5]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim] part_shape = [batch, seqlen_q, heads, num_split, dim] @@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ bid: T.int32, sid: T.int32, ): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], K_shared) + T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared) # TODO: Handle causal split case if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -52,20 +49,18 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ bid: T.int32, sid: T.int32, ): - T.copy( - V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], V_shared) + T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -91,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), ): - with T.Kernel( - T.ceildiv(seqlen_q, block_M), heads * batch, num_split, - threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -128,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # disable relevant tma copy and use SIMT as fallback for now - T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # TODO: Handle causal split case loop_range = ( - T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( - (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( - (seqlen_kv // num_split), block_N)) + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) for k in T.Pipelined(loop_range, num_stages=2): MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) T.copy(acc_o, O_shared) - T.copy( - O_shared, - Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :], - disable_tma=True) + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_q, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_q, dtype), ): with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): po_local = T.alloc_fragment([block_M, dim], dtype) @@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ lse_max_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), - o_shared: tilelang.layout.make_swizzled_layout(o_shared), - po_shared: tilelang.layout.make_swizzled_layout(po_shared), - }) + T.annotate_layout( + { + o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), + o_shared: tilelang.layout.make_swizzled_layout(o_shared), + po_shared: tilelang.layout.make_swizzled_layout(po_shared), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) - T.copy(glse[ - bz, - by, - :, - bx * block_M:(bx + 1) * block_M, - ], lse_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) for k in T.Pipelined(num_split): T.copy(lse_local[k, :], lse_local_split) @@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ for i in T.Parallel(block_M): lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] for k in T.Pipelined(num_split, num_stages=2): - T.copy( - Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], - po_shared, - disable_tma=True) + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True) T.copy(po_shared, po_local) for i in T.Parallel(block_M): lse_local_split[i] = lse_local[k, i] @@ -207,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ for i, j in T.Parallel(block_M, dim): o_accum_local[i, j] += po_local[i, j] * scale_local[i] T.copy(o_accum_local, o_shared) - T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) + T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) @T.prim_func def flashattn_mha_inference( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] - Output: T.Tensor(shape_q, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), ): flash_attn_split(Q, K, V, glse, Output_partial) combine(glse, Output_partial, Output) @@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ def ref_program(Q, K, V, glse, Output_partial, causal): assert causal is False dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal): block_N = 128 seqlen_kv = K.size(1) - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) @@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_scale = torch.exp2(scores_max_prev - scores_max) @@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) acc_o += torch.einsum( - 'bhqk,bkhd->bqhd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, :, None].transpose(1, 2) @@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal): gacc_o[ks, :, :, :, :] = acc_o glogsum[ks, :, :, :] = logsum - return glogsum.to(torch.float16).permute(1, 2, 0, - 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index a8d6849..b737f30 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -9,17 +9,18 @@ from example_fusedmoe_torch import * @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_shared(d_hidden, - d_expert, - n_shared_experts, - dtype, - num_tokens, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1): - +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): scale = 1.44269504 # log2(e) # Parameters @@ -36,17 +37,15 @@ def moe_forward_tilelang_shared(d_hidden, @T.prim_func def kernel_shared( - input: T.Tensor(input_shape, dtype), # type: ignore - shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore - shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore - shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore - up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): # Split the block to shared experts and routed experts input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) @@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden, # Fuse with SiLU and element-wise product for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) # Step 2: Compute down logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) @@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden, @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_routed(d_hidden, - d_expert, - n_routed_experts, - dtype, - group_sum, - group_count, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1, - k_pack=1, - coalesced_width=None): - +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=None, +): scale = 1.44269504 # log2(e) # Parameters @@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden, routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) - routed_expert_weights_shape = (group_sum) - group_sizes_shape = (n_routed_experts) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts @T.prim_func def kernel( - input: T.Tensor(input_shape, dtype), # type: ignore - routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore - routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore - routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore - routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore - group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore - up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): @@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden, cur_group_idx[0] = group_idx_for_bx[bx] cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) T.clear(gate_logits_local) T.clear(up_logits_local) for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): T.copy( - input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], input_shared, - coalesced_width=coalesced_width) + coalesced_width=coalesced_width, + ) T.copy( - routed_expert_gate[cur_group_idx[0], - by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], - routed_expert_gate_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, + routed_expert_gate[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_gate_shared, - gate_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) T.copy( - routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], + routed_expert_up[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_up_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, - routed_expert_up_shared, - up_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] for i, j in T.Parallel(block_token, block_dexpert): @@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden, cur_group_idx[0] = group_idx_for_bx[bx] cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) T.clear(output_local) for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): T.copy( - up_logits[m_start:m_start + block_token, - k * block_dexpert:(k + 1) * block_dexpert], + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], up_logits_shared, - coalesced_width=coalesced_width) + coalesced_width=coalesced_width, + ) T.copy( - routed_expert_down[cur_group_idx[0], - by * block_dhidden:(by + 1) * block_dhidden, - k * block_dexpert:(k + 1) * block_dexpert], - routed_expert_down_shared, - coalesced_width=coalesced_width) - T.gemm( - up_logits_shared, + routed_expert_down[ + cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], routed_expert_down_shared, - output_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True) for i, j in T.Parallel(block_token, block_dhidden): if i < actual_rows: - output[m_start + i, by * block_dhidden + - j] = output_local[i, j] * routed_expert_weights[m_start + i] + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] return kernel class Expert(nn.Module): - - def __init__(self, - config: Dict, - gate: torch.Tensor, - up: torch.Tensor, - down: torch.Tensor, - d_expert: Optional[int] = None): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): super().__init__() self.config = config self.act_fn = nn.SiLU() @@ -294,14 +265,13 @@ class Expert(nn.Module): class MoEGate(nn.Module): - def __init__(self, config: Dict, weights: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] self.num_experts: int = config["n_routed_experts"] self.d_hidden: int = config["d_hidden"] - self.W_g_weight = weights['router.weight'].t() + self.W_g_weight = weights["router.weight"].t() def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: logits = x @ self.W_g_weight @@ -312,76 +282,69 @@ class MoEGate(nn.Module): class MoE(nn.Module): - - def __init__(self, - config: Dict, - shared_kernel: tilelang.JITKernel, - routed_kernel: tilelang.JITKernel, - weights: Dict, - padding_M: int = 128): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): super().__init__() self.config = config self.shared_kernel = shared_kernel self.routed_kernel = routed_kernel self.padding_M = padding_M - self.experts = nn.ModuleList([ - Expert( - config, - gate=weights[f'experts.{i}.0.weight'], - up=weights[f'experts.{i}.1.weight'], - down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) - ]) + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) self.device = torch.device("cuda") self.gating_network = MoEGate(config, weights).to(self.device) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = Expert( config=config, - gate=weights['shared_experts.0.weight'], - up=weights['shared_experts.1.weight'], - down=weights['shared_experts.2.weight'], - d_expert=shared_expert_dim).to(self.device) + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) self.expert_cache = torch.zeros( - (config["batch_size"] * config["seq_len"], config["d_hidden"]), - dtype=torch.float16, - device=self.device) - self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], - dim=0) + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0) self.stacked_expert_tokens = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.stacked_expert_weights = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) self.stacked_expert_tokens_idxs = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.int64, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) self.up_logits_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_expert"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device + ) self.expert_output_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) self.up_logits_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_expert"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.expert_output_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -413,22 +376,20 @@ class MoE(nn.Module): self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs - self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ - idxs[start_idx:end_idx]] + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) - group_offset = torch.tensor( - tokens_per_expert - counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) group_padded_offsets = [0 for _ in range(len(group_sizes))] for i in range(1, len(group_sizes)): - group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( - (counts[i - 1] + 1) / self.padding_M) * self.padding_M + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M block_token = 128 - M = math.ceil( - self.config["batch_size"] * self.config["seq_len"] * - self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + self.config["n_routed_experts"] + ) group_idx_for_bx = [0 for _ in range(M)] for bx in range(M): @@ -437,8 +398,7 @@ class MoE(nn.Module): if m_start_padded >= group_padded_offsets[i]: group_idx_for_bx[bx] = i - group_padded_offsets = torch.tensor( - group_padded_offsets, dtype=torch.int32, device=self.device) + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) # Multi-stream execution @@ -448,11 +408,19 @@ class MoE(nn.Module): with torch.cuda.stream(routed_stream): # Tilelang version: Grouped GEMM - self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, - self.stacked_expert_w_up, self.stacked_expert_w_down, - self.stacked_expert_weights, group_sizes, group_offset, - group_padded_offsets, group_idx_for_bx, self.up_logits_routed, - self.expert_output_routed) + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + self.up_logits_routed, + self.expert_output_routed, + ) # Scatter reduce self.expert_cache = torch.scatter_reduce( @@ -460,14 +428,19 @@ class MoE(nn.Module): 0, self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.expert_output_routed, - reduce='sum') + reduce="sum", + ) routed_output = self.expert_cache.view(*orig_shape) with torch.cuda.stream(shared_stream): - - self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, - self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, - self.up_logits_shared, self.expert_output_shared) + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) shared_output = self.expert_output_shared.view(*orig_shape) torch.cuda.synchronize() @@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: config["d_expert"], config["n_shared_experts"], dtype=dtype_str, - num_tokens=config["batch_size"] * config["seq_len"]) + num_tokens=config["batch_size"] * config["seq_len"], + ) routed_kernel = moe_forward_tilelang_routed( config["d_hidden"], config["d_expert"], @@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: threads=256, num_stages=1, k_pack=1, - coalesced_width=2) + coalesced_width=2, + ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: return output -def main(d_hidden=7168, - d_expert=2048, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=8192): +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): config = { "dhidden": d_hidden, "dexpert": d_expert, @@ -536,7 +505,7 @@ def main(d_hidden=7168, "nexpertspertoken": n_experts_per_token, "bs": batch_size, "seqlen": seq_len, - "seed": 81394 + "seed": 81394, } data = generate_input(**config) diff --git a/examples/fusedmoe/example_fusedmoe_torch.py b/examples/fusedmoe/example_fusedmoe_torch.py index 00219c6..6b6322a 100644 --- a/examples/fusedmoe/example_fusedmoe_torch.py +++ b/examples/fusedmoe/example_fusedmoe_torch.py @@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional # Reference code in PyTorch class ExpertTorch(nn.Module): - def __init__(self, config: Dict, d_expert: Optional[int] = None): super().__init__() self.config = config @@ -25,7 +24,6 @@ class ExpertTorch(nn.Module): class MoEGateTorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] @@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module): class MoETorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.config = config - self.experts = nn.ModuleList( - [ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) self.gating_network = MoEGateTorch(config) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) @@ -67,8 +63,7 @@ class MoETorch(nn.Module): return routed_output + shared_output @torch.no_grad() - def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, - flat_expert_weights: torch.Tensor) -> torch.Tensor: + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: expert_cache = torch.zeros_like(x) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) @@ -91,8 +86,7 @@ class MoETorch(nn.Module): expert_out = expert(expert_tokens) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - expert_cache.scatter_reduce_( - 0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") return expert_cache @@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: moe = MoETorch(config) # Fill in the given weights of the model - moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) for i in range(num_experts): - gate_proj_weight = weights[f'experts.{i}.0.weight'] - up_proj_weight = weights[f'experts.{i}.1.weight'] - down_proj_weight = weights[f'experts.{i}.2.weight'] + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] # Transpose weights to match expected shape for nn.Linear moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) - moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) - moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) - moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) output = moe(input_tensor) @@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: # Input generation for the reference code -def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, - nexpertspertoken: int, bs: int, seqlen: int, - seed: int) -> Tuple[torch.Tensor, Dict, Dict]: - +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: # Really dumb but for now _ isn't parsing correctly. d_hidden = dhidden d_expert = dexpert @@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper "seq_len": seq_len, } - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) num_experts = n_routed_experts expert_dim = d_expert weights = {} - input_tensor = torch.randn((batch_size, seq_len, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen).contiguous() + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() # Initialize router weights - weights['router.weight'] = torch.randn( - (num_experts, d_hidden), device="cuda", dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) for i in range(num_experts): - weights[f'experts.{i}.0.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.1.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.2.weight'] = torch.randn( - (expert_dim, d_hidden), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) - - weights['shared_experts.0.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.1.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) return (input_tensor, weights, config) diff --git a/examples/fusedmoe/test_example_fusedmoe.py b/examples/fusedmoe/test_example_fusedmoe.py index 806aff4..ba84158 100644 --- a/examples/fusedmoe/test_example_fusedmoe.py +++ b/examples/fusedmoe/test_example_fusedmoe.py @@ -4,13 +4,8 @@ import example_fusedmoe_tilelang def test_example_fusedmoe_tilelang(): example_fusedmoe_tilelang.main( - d_hidden=1024, - d_expert=256, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=1024) + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) if __name__ == "__main__": diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index d9ccc25..ecda7e4 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True) # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__, flush=True) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu except ImportError: @@ -49,6 +50,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( DV = dv.shape[-1] block_S = 64 BS = S // block_S - dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( - (B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) @@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( for i_s in range(BS - 1, -1, -1): dh[:, i_s, :, :, :] = dh_tmp - dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), - dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) if use_g: for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H for i_s2 in range(block_S): - if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, - i_h] <= 0: - dv_tmp[i_b, i_s2, - i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - - G[i_b, i_s * block_S + i_s2, i_h]) + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) else: dv_tmp[i_b, i_s2, i_h, :] = 0 - dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp + dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp if use_g: G_last = G[:, i_s * block_S + block_S - 1, :] for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) - Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] for i_s2 in range(block_S): for i_k in range(DK): Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp *= scale - W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] torch.backends.cuda.matmul.allow_tf32 = True dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) @@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( @T.prim_func def kernel( - # Input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - h0: T.Tensor(h0_shape, dtype=input_dtype), - dht: T.Tensor(dht_shape, dtype=input_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - # Output - dh: T.Tensor(dh_shape, dtype=output_dtype), - dh0: T.Tensor(dh0_shape, dtype=state_dtype), - dv2: T.Tensor(dv2_shape, dtype=output_dtype), + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( T.use_swizzle(10) - T.annotate_layout({ - b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), - b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - }) + T.annotate_layout( + { + b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), + b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + } + ) if use_final_state_gradient: - T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) T.copy(b_dh_shared, b_dh_fragment) else: T.clear(b_dh_fragment) @@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( # Store the updated dh T.copy(b_dh_fragment, b_dh_shared) - T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Update dv - T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) if use_g: - T.copy( - G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh], - G_shared, - disable_tma=True) + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) T.copy(G_shared, G_fragment) G_last_local[0] = G_shared[block_S - 1] G_last_local_exp[0] = T.exp(G_last_local[0]) @@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): with T.Then(): - dv_fragment[i_s2, - i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] with T.Else(): dv_fragment[i_s2, i_v] = 0 - T.copy( - dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dv_shared) + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) T.copy(dv_shared, dv_fragment_2) for i_s2, i_v in T.Parallel(block_S, block_DV): dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] # Store the updated dv T.copy(dv_fragment, dv_shared) - T.copy( - dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) # Update dh - T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) - T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.clear(Q_fragment) if use_g: @@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( for i_s2, i_k in T.Parallel(block_S, DK): Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] - T.copy( - dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dO_shared) + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) T.copy(dO_shared, dO_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] @@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] if use_initial_state: - T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -444,44 +437,61 @@ def run_test( num_stages=0, use_torch=False, ): - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) # fla ref print("fla running...", flush=True) if use_g: - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) else: G = G.fill_(0) - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) # tilelang print("tilelang running...", flush=True) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, scale, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) # kernel = tilelang.compile(program) print(kernel.get_kernel_source()) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) - fla_time = do_bench( - chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) print(f"fla time: {fla_time} ms") @@ -496,19 +506,47 @@ def run_test( print("torch running...", flush=True) if use_g: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() else: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index cc384ad..43f1e97 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -10,6 +10,7 @@ from tilelang.autotuner import autotune # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h except ImportError: @@ -56,6 +57,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -83,18 +85,14 @@ def prepare_output( def get_configs(): import itertools + block_DK = [32, 64, 128] block_DV = [32, 64, 128] threads = [128, 256] num_stages = [1, 2, 3] _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) - configs = [{ - 'block_DK': c[0], - 'block_DV': c[1], - 'threads': c[2], - 'num_stages': c[3] - } for c in _configs] + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] return configs @@ -137,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - U: T.Tensor(U_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=output_dtype), - final_state: T.Tensor(final_state_shape, dtype=state_dtype), - V_new: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h( G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) - T.annotate_layout({ - b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - G_shared: tilelang.layout.make_swizzled_layout(G_shared), - }) + T.annotate_layout( + { + b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) T.use_swizzle(10) if use_initial_state: - T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) T.copy(b_h_shared, b_h_fragment) else: T.clear(b_h_fragment) for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): # Store previous result to the hidden tensor, like the epilogue - T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Recurrence - T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) # U - W * S - T.copy( - U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - U_shared) + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) T.copy(U_shared, U_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] @@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h( # Save V_new if save_new_value: T.copy(V_new_fragment, dst=V_new_shared) - T.copy( - V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) - T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) # use_g if use_g: G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] @@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h( with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): with T.Then(): V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( - (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695) + (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695 + ) with T.Else(): V_new_fragment[i_s2, i_v] = 0 G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) @@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h( # Save final state if store_final_state: - T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -279,17 +276,24 @@ def run_test( threads=128, num_stages=0, ): - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) - h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) # fla ref h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( @@ -300,13 +304,27 @@ def run_test( initial_state=initial_state, output_final_state=store_final_state, chunk_size=chunk_size, - save_new_value=save_new_value) + save_new_value=save_new_value, + ) # tilelang - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) @@ -320,19 +338,15 @@ def run_test( initial_state=initial_state, output_final_state=store_final_state, chunk_size=chunk_size, - save_new_value=save_new_value) + save_new_value=save_new_value, + ) tilelang_time = do_bench(kernel, K, W, U, G, initial_state) # check correctness try: h_ref_fp32 = h_ref.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32) - assert_similar( - h_ref_fp32, - h_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd h", - raise_assert=False) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) print("tilelang chunk gated delta rule fwd h passed √") except Exception as e: print("tilelang chunk gated delta rule fwd h failed ✗") @@ -346,7 +360,8 @@ def run_test( final_state_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd final_state", - raise_assert=False) + raise_assert=False, + ) print("tilelang chunk gated delta rule fwd final_state passed √") except Exception as e: print("tilelang chunk gated delta rule fwd final_state failed ✗") @@ -355,12 +370,7 @@ def run_test( try: V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) - assert_similar( - V_new_ref_fp32, - V_new_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd V_new", - raise_assert=False) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) print("tilelang chunk gated delta rule fwd V_new passed √") except Exception as e: print("tilelang chunk gated delta rule fwd V_new failed ✗") diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index 1c084be..bd1e9aa 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_fwd_o except ImportError: @@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( @T.prim_func def kernel( - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - HIDDEN: T.Tensor(H_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - O: T.Tensor(O_shape, dtype=output_dtype), + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), ): - with T.Kernel( - T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, - threads=threads) as (bv, bs, bbh): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): bb, bh = bbh // H, bbh % H Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) @@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o( G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - H_shared: tilelang.layout.make_swizzled_layout(H_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + H_shared: tilelang.layout.make_swizzled_layout(H_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) T.clear(A_fragment) T.clear(O_fragment) T.disable_warp_group_reg_alloc() for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - Q_shared) - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK, - bv * block_DV:(bv + 1) * block_DV], H_shared) + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) T.gemm(Q_shared, H_shared, O_fragment) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) @@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 @@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o( with T.Then(): A_fragment[i_s1, i_s2] = 0 - T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) T.gemm(A_shared, V_shared, O_fragment) @@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o( O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale T.copy(O_fragment, O_shared) - T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -191,8 +183,9 @@ def run_test( output_dtype_torch = getattr(torch, output_dtype) accum_dtype_torch = getattr(torch, accum_dtype) gate_dtype_torch = getattr(torch, gate_dtype) - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, - output_dtype_torch, accum_dtype_torch, gate_dtype_torch) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) scale = 1.0 / DK**0.5 O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) @@ -200,9 +193,25 @@ def run_test( block_S = chunk_size O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) try: diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 20aa841..66cb694 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_bwd_dqkwg except ImportError: @@ -108,10 +109,8 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_chunk_o_bwd_dqkwg( # task config B, @@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( @T.prim_func def kernel( - # input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dh: T.Tensor(dh_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - # output - dq: T.Tensor(dq_shape, dtype=output_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dw: T.Tensor(dw_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): - with T.Kernel( - T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, - threads=threads) as (bk, bs, bbh): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): bb, bh = bbh // H, bbh % H V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) @@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg( T.use_swizzle(10) - T.annotate_layout({ - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - h_shared: tilelang.layout.make_swizzled_layout(h_shared), - dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - q_shared: tilelang.layout.make_swizzled_layout(q_shared), - k_shared: tilelang.layout.make_swizzled_layout(k_shared), - }) + T.annotate_layout( + { + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + h_shared: tilelang.layout.make_swizzled_layout(h_shared), + dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) T.clear(dg_last_local) T.clear(G_last_local) @@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg( T.clear(dw_fragment) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) - T.copy( - dO[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dO_shared) - T.copy( - h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], h_shared) - T.copy( - dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) if use_g: T.clear(dg_last_fragment_scalar) @@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % - block_DV] * dh_shared[i_kv // block_DV, - i_kv % block_DV] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) dg_last_local[0] += dg_last_fragment_scalar[0] @@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg( T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) if use_dw: - T.copy( - dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dv_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) if use_dw: for i_s, i_k in T.Parallel(block_S, block_DK): dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] - T.copy( - dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - - T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - q_shared) - T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - k_shared) + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared) T.copy(q_shared, q_fragment) T.copy(k_shared, k_fragment) @@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg( dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) for i_s, i_k in T.Parallel(block_S, block_DK): - dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, - bh]) * scale + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] @@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s, i_k in T.Parallel(block_S, block_DK): with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): with T.Then(): - dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( - G_last_local[0] - G[bb, bs * block_S + i_s, bh]) + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh]) with T.Else(): dk_fragment[i_s, i_k] = 0 T.clear(dg_fragment_reduce_tmp) @@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg( dg_last_local[1] = dg_last_fragment_scalar_2[0] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 >= i_s2 and - G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.Then(): - ds_fragment[i_s1, i_s2] = ds_fragment[ - i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) * scale + ds_fragment[i_s1, i_s2] = ( + ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale + ) with T.Else(): ds_fragment[i_s1, i_s2] = 0 @@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg( T.clear(ds_fragment_positive_transpose) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - ds_fragment_positive[ - i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) @@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s in T.Parallel(block_S): with T.If(i_s >= block_S - 1): # noqa: SIM117 with T.Then(): - dg_fragment_final[ - i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] - - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) for i_s in T.Parallel(block_S): dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] @@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s, i_k in T.Parallel(block_S, block_DK): dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) return kernel @@ -442,32 +412,53 @@ def run_test( threads=256, num_stages=0, ): - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) # ref if use_g: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) else: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) # tilelang - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, - block_DK, block_DV, threads, num_stages) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) if use_g: diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index d07a477..af2b08e 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd except ImportError: @@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=accum_dtype), - A: T.Tensor(output_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd( G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) T.fill(A_fragment, 0) T.disable_warp_group_reg_alloc() @@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) @@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 else: @@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) - T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel @@ -149,24 +149,21 @@ def run_test( threads, num_stages, ): - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) # reference if use_g: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) else: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) # tilelang block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) try: @@ -192,7 +189,8 @@ def main(): use_g=True, block_DK=64, threads=128, - num_stages=2) + num_stages=2, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 9896c7e..13547cd 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -10,6 +10,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.utils.cumsum import chunk_local_cumsum_scalar except ImportError: @@ -20,11 +21,8 @@ import torch @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) def tilelang_chunk_local_cumsum_scalar( # task config B, @@ -42,35 +40,35 @@ def tilelang_chunk_local_cumsum_scalar( use_fragment=False, ): G_shape = (B, H, S) if head_first else (B, S, H) - assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == block_S, "chunk_size must be equal to block_S" @T.prim_func def kernel( - G: T.Tensor(G_shape, dtype=input_dtype), - G_new: T.Tensor(G_shape, dtype=output_dtype), + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") if head_first: - T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) else: - T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) if use_fragment: G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") T.copy(G_shared, G_fragment) T.cumsum(G_fragment, dim=1, reverse=reverse) if head_first: - T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) else: T.cumsum(G_shared, dim=1, reverse=reverse) if head_first: - T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) return kernel @@ -113,11 +111,8 @@ def run_test( # reference cumsum G_new_ref = chunk_local_cumsum_scalar( - g=G, - chunk_size=chunk_size, - reverse=reverse, - head_first=head_first, - output_dtype=getattr(torch, output_dtype)) + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) # tilelang cumsum block_S = chunk_size @@ -162,7 +157,8 @@ def main(): input_dtype="float32", output_dtype="float32", threads=256, - use_fragment=False) + use_fragment=False, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py index 0a0983a..874e25c 100644 --- a/examples/gdn/example_wy_fast.py +++ b/examples/gdn/example_wy_fast.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd except ImportError: @@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=output_dtype), - W: T.Tensor(K_shape, dtype=output_dtype), - U: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd( W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), - U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), + U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), + } + ) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(U_fragment, U_shared) - T.copy( - U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - W_Beta_shared[i_s, - i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(W_fragment, W_shared) - T.copy( - W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) return kernel @@ -159,15 +153,8 @@ def run_test( num_stages, ): K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) @@ -191,7 +178,8 @@ def run_test( block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) @@ -224,7 +212,8 @@ def main(): block_DK=64, block_DV=32, threads=128, - num_stages=3) + num_stages=3, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 42a0040..5b0230e 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -10,6 +10,7 @@ import tilelang.language as T # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr except ImportError: @@ -93,10 +94,8 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_wy_fast_bwd( # task config B, @@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - # output - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd( T.clear(dbeta_fragment_v) T.clear(dg_fragment) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd( # Update dk for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - K_shared_beta_g[i_s, - i_k2] = K_shared[i_s, - i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] - T.copy( - dw[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dw_shared) + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): - dk_fragment[ - i_s, - i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ - i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ - i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) # correct dk - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dv for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] - T.copy( - du[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], du_shared) + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) for i_s, i_v2 in T.Parallel(block_S, block_DV): @@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd( # for i_s, i_v2 in T.Parallel(block_S, block_DV): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] for i_s, i_v2 in T.Parallel(block_S, block_DV): - dbeta_fragment_reduce_tmpv[i_s, - i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, - i_v2] + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) - T.copy( - dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) # Temporary store dbeta, dg and dA for i_s in T.Parallel(block_S): dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] # correct dA - T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, @@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), - dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), - dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split( T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_2) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split( # for i_s in T.Parallel(block_S): # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] - T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # Update dA @@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) + dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) with T.Else(): dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) @@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split( # Update dk using previous dk T.clear(A_fragment) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dk_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) T.copy(dk_shared, dk_fragment) for i_s, i_k2 in T.Parallel(block_S, block_DK): K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] @@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split( # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, - i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, - i_k2] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dg and dbeta T.copy(A_fragment, A_shared) @@ -460,19 +428,25 @@ def run_test( threads=128, num_stages=0, ): - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -480,28 +454,55 @@ def run_test( dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # ref - dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( - K, V, G, Beta, A, dw, du, cu_seqlens=None) + dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) from test_utils import assert_similar + assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index 75a6217..a51936e 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -25,16 +25,10 @@ num_stages = 1 def test_example_wy_fast_compilation(): from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) # tilelang block_S = chunk_size kernel = tilelang_recompute_w_u_fwd( @@ -52,22 +46,31 @@ def test_example_wy_fast_compilation(): block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) def test_example_wy_fast_bwd_split_compilation(): from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation(): dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) def test_example_chunk_o_compilation(): from example_chunk_o import tilelang_chunk_fwd_o, prepare_input - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) scale = 1.0 / DK**0.5 block_S = chunk_size - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 def test_example_chunk_o_bwd_compilation(): from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, - block_DK, block_DV, threads, num_stages) - - dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, - W) # noqa: F841 + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 if use_g: dg_tilelang = dg_tilelang.sum(dim=0) def test_example_chunk_scaled_dot_kkt_compilation(): from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) # noqa: F841 def test_example_cumsum_compilation(): from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) block_S = chunk_size @@ -158,33 +239,79 @@ def test_example_cumsum_compilation(): def test_example_chunk_delta_h_compilation(): from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) - h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, - initial_state) # noqa: F841 + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 def test_example_chunk_delta_bwd_compilation(): from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, 1.0, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 diff --git a/examples/gdn/test_utils.py b/examples/gdn/test_utils.py index 37f8d8e..3588551 100644 --- a/examples/gdn/test_utils.py +++ b/examples/gdn/test_utils.py @@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if raise_assert: raise AssertionError else: diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd38..2c234d1 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -4,12 +4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 661ef12..badc334 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs def get_best_config(M, N, K, with_roller=False): - def kernel( block_M=None, block_N=None, @@ -120,12 +121,11 @@ def get_best_config(M, N, K, with_roller=False): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( + ) + .set_profile_args( supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, skip_check=False, ) + ) return autotuner.run(warmup=3, rep=20) @@ -167,52 +170,20 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tl.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"): @T.prim_func def gemm_autotune( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -236,11 +207,7 @@ def matmul(M, return gemm_autotune -def main(M: int = 4096, - N: int = 4096, - K: int = 4096, - use_autotune: bool = False, - with_roller: bool = False): +def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): use_autotune = True if use_autotune: result = get_best_config(M, N, K, with_roller) @@ -266,15 +233,7 @@ if __name__ == "__main__": parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=False, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() main(args.m, args.n, args.k, args.use_autotune, args.with_roller) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 5c014ce..488e5bf 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -4,7 +4,8 @@ import tilelang import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func @@ -99,12 +100,11 @@ def tl_matmul( @T.prim_func def gemm_intrinsics( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -112,10 +112,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -123,7 +125,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -133,7 +134,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index a2a7122..6fc0e5a 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -5,22 +5,12 @@ import argparse @tilelang.jit(out_idx=[-1]) -def matmul_non_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float"): - +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -43,18 +33,9 @@ def matmul_non_persistent(M, @tilelang.jit(out_idx=[-1]) -def matmul_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float", - use_persistent_primitive=True): - +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True +): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) n_blocks = T.ceildiv(N, block_N) @@ -63,9 +44,9 @@ def matmul_persistent(M, @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -90,9 +71,9 @@ def matmul_persistent(M, @T.prim_func def main_persistent_primitive( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -100,8 +81,7 @@ def matmul_persistent(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) - for bx, by in T.Persistent( - [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[bx * block_M, k * block_K], A_shared) @@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096): num_stages = 3 persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) - persistent_profiler = persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Persistent GEMM: All check passed.") persistent_latency = persistent_profiler.do_bench(warmup=500) print(f"Persistent GEMM Latency: {persistent_latency} ms") print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") - non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, - num_stages) - non_persistent_profiler = non_persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Non-Persistent GEMM: All check passed.") non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) @@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=8192, help='N dimension') - parser.add_argument('--K', type=int, default=8192, help='K dimension') + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") args = parser.parse_args() M, N, K = args.M, args.N, args.K main(M, N, K) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index f472741..d1eb11d 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -4,12 +4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm_schedule( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 0e6ace7..4c58144 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -17,10 +17,8 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) return [a, b] @@ -35,27 +33,24 @@ def get_configs(): valid_configs = [] - for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, - num_stages, num_threads, k_packs, - gemm_types): - valid_configs.append({ - "block_M": m, - "block_N": n, - "block_K": k, - "num_stages": stages, - "num_threads": t, - "k_pack": kp, - "gemm_type": gemm_type, - }) + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) return valid_configs @tilelang.autotune( - configs=get_configs(), - cache_input_tensors=True, - ref_prog=ref_program, - manual_check_prog=manual_check_prog, - supply_prog=supply_prog) + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): dtype = "float8_e4m3fnuz" @@ -63,12 +58,11 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa @T.prim_func def gemm_fp8_rs( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_local = T.alloc_fragment((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -77,24 +71,17 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_local) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_local, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @T.prim_func def gemm_fp8_ss( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a403ed0..1ecd344 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -13,12 +13,11 @@ def calc_diff(x, y): @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): - @T.prim_func def gemm_fp8( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype): kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) - b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) c = kernel(a, b) @@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 1024, "float8_e4m3") + test_gemm_fp8(1024, 1024, 1024, "float8_e5m2") if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 1d9207a..3af4c3d 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): @T.prim_func def gemm_fp8_2xAcc( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype): kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.rand(M, K, dtype=torch.float16, device='cuda') + a = torch.rand(M, K, dtype=torch.float16, device="cuda") a = (100 * (2 * a - 1)).to(dtype=torch_dtype) - b = torch.rand(N, K, dtype=torch.float16, device='cuda') + b = torch.rand(N, K, dtype=torch.float16, device="cuda") b = (100 * (2 * b - 1)).to(dtype=torch_dtype) c = kernel(a, b) - ref_c = (a.float() @ b.float().T) + ref_c = a.float() @ b.float().T diff = calc_diff(c, ref_c) print(f"diff: {diff}") @@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 8192, "float8_e4m3") + test_gemm_fp8(1024, 1024, 8192, "float8_e5m2") if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 0e2c437..6e2d41b 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -5,7 +5,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -110,12 +111,11 @@ def tl_matmul( @T.prim_func def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -123,10 +123,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -134,7 +136,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -144,7 +145,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index 4628a99..5cb42e3 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") - print( - f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" - ) + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index a58e5a7..be43f4e 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -62,7 +61,8 @@ jit_kernel = tilelang.compile( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) # 3. Test the kernel in Python with PyTorch data import torch diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 9008c7e..88614f5 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -25,9 +25,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -40,15 +40,7 @@ def matmul( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" num_stages = 2 threads = 256 -func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) jit_kernel = tilelang.compile( func, out_idx=[2], @@ -75,7 +66,8 @@ jit_kernel = tilelang.compile( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) @@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"Latency: {latency} ms") -print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py index 5125aed..fe3b152 100644 --- a/examples/gemm_sp/example_custom_compress.py +++ b/examples/gemm_sp/example_custom_compress.py @@ -17,77 +17,76 @@ torch.manual_seed(42) DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + }, } ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, - thread_num, policy, enable_rasterization, use_cutlass_layout): +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): e_factor, e_dtype = (16, "int16") @T.prim_func def gemm_sp_fp16_custom_compress( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), "float16"), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), "float16"), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), "float16") E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), "float16") C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) if use_cutlass_layout: - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + } + ) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -108,8 +107,7 @@ def torch_compress(dense): A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. """ if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") m, k = dense.shape @@ -131,9 +129,7 @@ def torch_compress(dense): if m % 32 != 0: raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" - ) + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") if dense.dtype != torch.float: ksparse = 4 @@ -194,19 +190,13 @@ def torch_compress(dense): sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: - sparse = dense_2.gather(-1, - idxs0.unsqueeze(-1) // 2).view( - m, k // 2) # type: ignore[possibly-undefined] + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12)) + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) elif quadbits_per_meta_elem == 8: meta = ( meta_n[:, :, 0] @@ -216,7 +206,8 @@ def torch_compress(dense): | (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28)) + | (meta_n[:, :, 7] << 28) + ) return (sparse, meta) @@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor: @tilelang.jit( - out_idx=[1, 2], pass_configs={ + out_idx=[1, 2], + pass_configs={ tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, - }) + }, +) def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): e_factor, e_dtype = ARCH_INFO["8.0"] e_K = K // e_factor @@ -249,23 +242,21 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): @T.prim_func def kernel( - A: T.Tensor((M, K), dtype), - A_sp: T.Tensor((M, K // 2), dtype), - E: T.Tensor((M, e_K), e_dtype), + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) if use_cutlass_layout: - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + } + ) T.clear(A_sp_shared) T.clear(E_shared) # TODO: alloc_var seems buggy here @@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): non_zero_elt_log_idx[1] = 3 for i in T.serial(elem): val = non_zero_elt_log_idx[i] - E_shared[tm, a_k // e_factor] |= T.shift_left( - val, 4 * (g_i % (e_factor // group)) + 2 * i) + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) @@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): def main(): - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor") - parser.add_argument( - "--use_torch_compressor", action='store_true', help="Use torch sparse for reference") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") args = parser.parse_args() kernel = matmul_sp_fp16_custom_compress( - args.m, - args.n, - args.k, - args.accum_dtype, - **DEFAULT_CONFIG[args.cfg][args.accum_dtype], - use_cutlass_layout=args.use_cutlass_layout) + args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout + ) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) if args.use_torch_compressor: assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" a_sparse, e = torch_compress(a) else: - a_sparse, e = compress_kernel( - args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)( - a) + a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a) c = kernel(a_sparse, e, b) @@ -346,9 +322,7 @@ def main(): assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) - print( - f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}" - ) + print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) @@ -356,8 +330,8 @@ def main(): total_flops = 2 * args.m * args.n * args.k tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 91682a9..828ca43 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version() DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + }, } ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, - enable_rasterization): +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def gemm_sp_fp16( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), "float16"), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), "float16"), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), "float16") E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), "float16") C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", block_k=block_K, arch=arch), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", block_k=block_K, arch=arch), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared) @@ -107,25 +104,15 @@ def main(): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") args = parser.parse_args() - kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, - **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) + kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) - a_sparse, e = compress( - a, - transposed=False, - block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'], - arch=arch) + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -140,8 +127,8 @@ def main(): total_flops = 2 * args.m * args.n * args.k tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index c966697..320a699 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -3,27 +3,16 @@ import tilelang.language as T @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index 145d622..dfd8471 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -3,27 +3,16 @@ import tilelang.language as T @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 31cf406..2d83586 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -135,7 +135,6 @@ def tl_matmul_streamk( C: T.Tensor, C_local: T.LocalBuffer, ): - for p in T.serial(sm_patition_factor): tile_id = pid + streamk_tiles + p * total_sm pid_m = tile_id // T.ceildiv(N, block_N) @@ -150,12 +149,11 @@ def tl_matmul_streamk( @T.prim_func def main( - A: T.Tensor(A_shape, dtypeAB), - B: T.Tensor(B_shape, dtypeAB), - C: T.Tensor((M, N), dtypeC), + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB) A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 3772dc6..00cbac0 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -20,12 +20,11 @@ def naive_gemv( dtype: str = "float16", accum_dtype: str = "float", ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: tn = T.get_thread_binding(0) # tn = threadIdx.x @@ -38,8 +37,7 @@ def naive_gemv( A_shared[tk] = A[bk * BLOCK_K + tk] B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] for tk in T.serial(BLOCK_K): - C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, - tk].astype(accum_dtype) + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) C[bn * BLOCK_N + tn] = C_reg[0] return main @@ -54,12 +52,11 @@ def naive_splitk_gemv( dtype: str = "float16", accum_dtype: str = "float", ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: tn = T.get_thread_binding(0) @@ -95,9 +92,9 @@ def splitk_gemv( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -136,9 +133,9 @@ def splitk_gemv_vectorized( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm( def get_block_template_configs(): iter_params = dict( - block_M=[2, 4, 8, 32, 64, 128], - block_N=[2, 4, 8, 32, 64, 128], - num_stages=[0, 1, 2, 3, 4], - threads=[32, 64, 128, 256]) + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -237,18 +233,9 @@ def get_block_template_configs(): }, out_idx=[2], ) -def gemv_alloc_reducer(M, - N, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16", - accum_dtype: str = "float"): - +def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"): @T.prim_func - def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, - dtype)): # type: ignore + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") T.clear(o_reducer) @@ -295,9 +282,9 @@ def get_autotuned_kernel( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -315,9 +302,9 @@ def get_autotuned_kernel( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -327,7 +314,8 @@ def get_autotuned_kernel( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -355,8 +343,7 @@ def main(do_bench: bool = True): check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) - check_correctness_and_bench( - gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) print("Test passed!") diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index ac8da7e..b1af536 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -5,21 +5,8 @@ import tilelang import tilelang.language as T -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_fwd(batch_sum, - batch_count, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). @@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum, @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - - with T.Kernel( - T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum, m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): @@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum, class _GroupedGEMM(torch.autograd.Function): - @staticmethod def forward(ctx, a, b, batch_sizes): block_M = 64 @@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32) - batch_padded_offsets = torch.tensor( - batch_padded_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) - kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets) @@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function): return x A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] - kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) dB = kernel(A, grad_output, batch_sizes, batch_offsets) return None, dB, None @@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_bwd(batch_sum, - batch_count, - M, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). @@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum, @T.prim_func def kernel( - A: T.Tensor([batch_sum, M], dtype), # type: ignore - B: T.Tensor([batch_sum, N], dtype), # type: ignore - C: T.Tensor([batch_count, M, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - - with T.Kernel( - T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared([block_K, block_M], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum, T.clear(C_local) for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): for i, j in T.Parallel(block_K, block_M): - A_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, - bx * block_M + j], 0) + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) for i, j in T.Parallel(block_K, block_N): - B_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, - by * block_N + j], 0) + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) T.gemm(A_shared, B_shared, C_local, transpose_A=True) T.copy(C_local, C[bz, bx * block_M, by * block_N]) @@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum, return kernel -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): - +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, False, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) A.requires_grad_(False) B.requires_grad_(True) @@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, O.backward(dO, retain_graph=True) dB, B.grad = B.grad.clone(), None - if ( - torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \ - torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2) - ): + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): print("✅ Tilelang and Torch match") else: print("❌ Tilelang and Torch mismatch") @@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -301,14 +236,4 @@ if __name__ == "__main__": num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 9b58e3a..8f77105 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): torch.Tensor: Resulting tensor after grouped matrix multiplication. """ assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" - assert b.shape[0] == len( - batch_sizes), "The first dimension of b must match the length of batch_sizes" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" # Initialize output tensor output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) @@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): @tilelang.jit(out_idx=[2]) -def grouped_gemm(batch_sizes_list, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). @@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list, @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) @@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list, m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): @@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M batch_sum = sum(batch_sizes_list) - kernel = grouped_gemm( - tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) # print(kernel.get_kernel_source()) device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) # print(out) @@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if profile: profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) - latency = profiler.do_bench( - warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) print(f"Latency: {latency} ms") print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") @@ -173,12 +144,11 @@ def test_grouped_gemm(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -190,14 +160,4 @@ if __name__ == "__main__": num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 531d468..64eb9bb 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -17,7 +17,7 @@ def is_pow_of_2(n): def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" - elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype] + elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype] logN = int(math.log2(n)) threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] @@ -40,23 +40,21 @@ def hadamard(b, n, dtype): # print(f'{exchange_round=}') @T.macro - def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), - round: int): + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int): tx = T.get_thread_binding(0) for i in T.serial(round): tx_stride = 1 << i another_tx = tx ^ tx_stride - sign = ( - tx >> i - ) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] for j in T.Pipelined(thread_elem, num_stages=1): buf[j] = T.tvm_warp_shuffle( - 0xffffffff, # mask of all threads + 0xFFFFFFFF, # mask of all threads local[j], another_tx % warp_size, warp_size, - warp_size) + warp_size, + ) local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) @T.prim_func @@ -78,10 +76,8 @@ def hadamard(b, n, dtype): for j in T.serial(chunknum): chunkbase = j * chunksize for k in T.serial(chunksize // 2): - local[chunkbase + - k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] - local[chunkbase + k + chunksize // - 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] # 3. Hadamard inside warp, n<=512 # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory @@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor): assert x.ndim == 2 dim = x.shape[-1] assert is_pow_of_2(dim) - return F.linear( - x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--dim', type=int, default=32768, help='Dimension') + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") args = parser.parse_args() B, D = args.batch, args.dim - x = torch.randn((B, D), device='cuda') - kernel = hadamard(B, D, 'float32') + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, "float32") y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) - print('All tests passed.') + print("All tests passed.") profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) latency = profiler.do_bench(warmup=100) print("Tile-lang: {:.2f} ms".format(latency)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb index acb318c..196ddfc 100644 --- a/examples/lazy_jit/lazyjit.en.ipynb +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -9,6 +9,7 @@ "source": [ "import sys\n", "from pathlib import Path\n", + "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", @@ -61,7 +62,7 @@ " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", - " block_K: int = 32\n", + " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -94,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", @@ -118,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, @@ -218,8 +219,8 @@ "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", - " A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n", - " B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -265,8 +266,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -295,18 +296,17 @@ "source": [ "from typing import Any\n", "\n", + "\n", "@tilelang.lazy_jit\n", - "def as_contingious(\n", - " A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n", - "):\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", " M, N = A.shape\n", " B = T.empty((M, N), A.dtype)\n", " block_M = 128\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " T.copy(\n", - " A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n", - " B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", " )\n", " return B" ] @@ -318,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 1024, device='cuda')\n", + "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contingious(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" @@ -370,8 +370,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -416,8 +416,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -496,18 +496,20 @@ "source": [ "from itertools import product\n", "\n", + "\n", "def get_configs():\n", " return [\n", " {\n", - " 'A': T.Tensor((1024, 1024), T.float32),\n", - " 'B': T.Tensor((1024, 1024), T.float32),\n", - " 'block_M': block_M,\n", - " 'block_N': block_N,\n", - " 'block_K': block_K,\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " ]\n", "\n", + "\n", "gemm.par_compile(get_configs())" ] }, @@ -579,7 +581,8 @@ "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", - " x = 1 # noqa: F841\n", + " x = 1 # noqa: F841\n", + "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", @@ -591,6 +594,7 @@ " idx = T.alloc_var(T.int32, 0)\n", " macro_with_ref(x[idx])\n", "\n", + "\n", "foo" ] }, @@ -616,7 +620,7 @@ " A: T.Tensor[[T.dyn], Any],\n", " fn,\n", "):\n", - " N, = A.shape\n", + " (N,) = A.shape\n", " B = T.empty((N,), dtype=A.dtype)\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", @@ -624,6 +628,8 @@ " idx = bx * block_N + i\n", " B[idx] = fn(A[idx])\n", " return B\n", + "\n", + "\n", "@T.macro\n", "def add_one(x):\n", " return x + 1" @@ -636,7 +642,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, device='cuda')\n", + "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" @@ -670,10 +676,11 @@ " var = var * 3 + 1\n", " n31(x * 3 + 1, var)\n", "\n", + "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", - " n31(n, A[0])\n" + " n31(n, A[0])" ] }, { @@ -694,7 +701,7 @@ } ], "source": [ - "A = torch.tensor([100], dtype=torch.int32, device='cuda')\n", + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] @@ -745,12 +752,15 @@ "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", + "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", - " a = s + c # noqa: F841\n", - " b = s - c # noqa: F841\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", "foo" ] } diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb index fb9b71b..d6db4c7 100644 --- a/examples/lazy_jit/lazyjit.zh.ipynb +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -9,6 +9,7 @@ "source": [ "import sys\n", "from pathlib import Path\n", + "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", @@ -61,7 +62,7 @@ " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", - " block_K: int = 32\n", + " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -94,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", @@ -118,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, @@ -218,8 +219,8 @@ "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", - " A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n", - " B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -265,8 +266,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -295,18 +296,17 @@ "source": [ "from typing import Any\n", "\n", + "\n", "@tilelang.lazy_jit\n", - "def as_contingious(\n", - " A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n", - "):\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", " M, N = A.shape\n", " B = T.empty((M, N), A.dtype)\n", " block_M = 128\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " T.copy(\n", - " A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n", - " B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", " )\n", " return B" ] @@ -318,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 1024, device='cuda')\n", + "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contingious(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" @@ -370,8 +370,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -416,8 +416,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -496,18 +496,20 @@ "source": [ "from itertools import product\n", "\n", + "\n", "def get_configs():\n", " return [\n", " {\n", - " 'A': T.Tensor((1024, 1024), T.float32),\n", - " 'B': T.Tensor((1024, 1024), T.float32),\n", - " 'block_M': block_M,\n", - " 'block_N': block_N,\n", - " 'block_K': block_K,\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " ]\n", "\n", + "\n", "gemm.par_compile(get_configs())" ] }, @@ -579,7 +581,8 @@ "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", - " x = 1 # noqa: F841\n", + " x = 1 # noqa: F841\n", + "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", @@ -591,6 +594,7 @@ " idx = T.alloc_var(T.int32, 0)\n", " macro_with_ref(x[idx])\n", "\n", + "\n", "foo" ] }, @@ -616,7 +620,7 @@ " A: T.Tensor[[T.dyn], Any],\n", " fn,\n", "):\n", - " N, = A.shape\n", + " (N,) = A.shape\n", " B = T.empty((N,), dtype=A.dtype)\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", @@ -624,6 +628,8 @@ " idx = bx * block_N + i\n", " B[idx] = fn(A[idx])\n", " return B\n", + "\n", + "\n", "@T.macro\n", "def add_one(x):\n", " return x + 1" @@ -636,7 +642,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, device='cuda')\n", + "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" @@ -670,10 +676,11 @@ " var = var * 3 + 1\n", " n31(x * 3 + 1, var)\n", "\n", + "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", - " n31(n, A[0])\n" + " n31(n, A[0])" ] }, { @@ -694,7 +701,7 @@ } ], "source": [ - "A = torch.tensor([100], dtype=torch.int32, device='cuda')\n", + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] @@ -745,12 +752,15 @@ "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", + "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", - " a = s + c # noqa: F841\n", - " b = s - c # noqa: F841\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", "foo" ] } diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 568bcc5..7cbfc46 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -13,20 +13,20 @@ from typing import Optional, Tuple pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tl_fused_chunk_bwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel( @T.prim_func def fused_chunk_linear_attn_bwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - dO: T.Tensor([B, S, H, DV], dtype), # type: ignore - dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H @@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel( dh = T.alloc_fragment([BK, BV], accum_dtype) dh_shared = T.alloc_shared([BK, BV], dtype) - T.annotate_layout({ - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared) - }) + T.annotate_layout( + { + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) T.use_swizzle(10) T.clear(h) @@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel( # Calculate dQ for i in T.Pipelined(0, NT): - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) - T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - do) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) T.gemm(do, v, ds, transpose_B=True, clear_accum=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel( for row, col in T.Parallel(chunk_size, BK): dq[row, col] *= scale T.copy(dq, dq_shared) - T.atomic_add( - dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], - dq_shared) + T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared) # Calculate dK, dV (reversely) for i in T.Pipelined(1, NT + 1): start = NT - i for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy( - K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], k) - T.copy( - V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], v) - T.copy( - dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], do) + T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) # Calculate dk - T.gemm( - v, do, ds, transpose_B=True, clear_accum=True - ) # ds here actually means `s`, but we simply reuse the buffer `ds` + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` for row, col in T.Parallel(chunk_size, chunk_size): ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, q, dk, clear_accum=True) @@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel( T.gemm(q, do, dh, transpose_A=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], dk_shared) + T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared) T.copy(dv, dv_shared) - T.atomic_add( - dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], dv_shared) + T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared) return fused_chunk_linear_attn_bwd @@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO): return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=1024, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q = l2norm_fwd(q)[0].requires_grad_(True) @@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128): o_ref, _ = ref_program(q, k, v) o_ref.backward(do, retain_graph=True) - assert torch.allclose( - dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}' - assert torch.allclose( - dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}' - assert torch.allclose( - dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}" + assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}" + assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}" + print("Passed all tests!✅") # Benchmark q.grad = k.grad = v.grad = None o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) - t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 03900a7..3d28f92 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -14,20 +14,20 @@ from typing import Optional, Tuple pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) def tl_fused_chunk_fwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel( @T.prim_func def fused_chunk_linear_attn_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype), + ): # type: ignore with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H @@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel( for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel( T.gemm(k, v, h, transpose_A=True) T.gemm(q, h_shared, o) T.copy(o, o_shared) - T.atomic_add( - O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - o_shared) + T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared) # Output final state - T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) + T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV]) return fused_chunk_linear_attn_fwd @@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v): B, S, H, D = q.shape kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) print(kernel.get_kernel_source()) - o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) h = kernel(q, k, v, o) return o, h -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=512, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q, _ = l2norm_fwd(q) @@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128): o, h = tl_fused_chunk_fwd(q, k, v) o_ref, h_ref = ref_program(q, k, v) - assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}' - assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}" + assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}" + print("Passed all tests!✅") - t1 = do_bench( - lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), - backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py index add4905..53b6cf9 100644 --- a/examples/linear_attention/example_mamba_chunk_scan.py +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -9,6 +9,7 @@ import itertools def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) return out @@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -77,19 +74,21 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) @@ -97,20 +96,20 @@ def chunk_scan_fwd(batch, @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") @@ -136,27 +135,32 @@ def chunk_scan_fwd(batch, m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -165,34 +169,47 @@ def chunk_scan_fwd(batch, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -200,27 +217,40 @@ def chunk_scan_fwd(batch, T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_scan_fwd( batch, seq_len, @@ -234,7 +264,8 @@ if __name__ == "__main__": block_K=64, block_Dstate=128, num_stages=2, - threads=128) + threads=128, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py index ad3df0d..6aefde7 100644 --- a/examples/linear_attention/example_mamba_chunk_state.py +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -10,6 +10,7 @@ import itertools def chunk_state_triton(B, x, dt, dA_cumsum): from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) @@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum): x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) - return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), - dt.to(x.dtype), x) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) def get_configs(): - iter_params = dict( - block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[4]) -def chunk_state_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - num_stages=2, - threads=128): +def chunk_state_fwd( + batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func - def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( - (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( - (batch, nchunks, nheads, headdim, dstate), dtype)): - with T.Kernel( - nheads, - T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + def main( + B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), + ): + with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by): x_shared = T.alloc_shared((block_K, block_M), dtype) x_local = T.alloc_fragment((block_K, block_M), dtype) xt_local = T.alloc_fragment((block_M, block_K), dtype) @@ -101,20 +89,24 @@ def chunk_state_fwd(batch, m_idx = bx // T.ceildiv(dstate, block_N) n_idx = bx % T.ceildiv(dstate, block_N) - T.annotate_layout({ - x_shared: tilelang.layout.make_swizzled_layout(x_shared), - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) - }) + T.annotate_layout( + {x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)} + ) dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] T.clear(acc_o) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cumsum_shared) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx * block_M : (m_idx + 1) * block_M, + ], + x_shared, + ) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dA_cumsum_shared, dA_cumsum_local) T.copy(dt_shared, dt_local) for i in T.Parallel(block_K): @@ -123,47 +115,50 @@ def chunk_state_fwd(batch, for i, j in T.Parallel(block_M, block_K): xt_local[i, j] = x_local[j, i] * scale[j] T.copy( - B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz // (nheads // ngroups), - n_idx * block_N:(n_idx + 1) * block_N], B_shared) + B[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N, + ], + B_shared, + ) T.gemm(xt_local, B_shared, acc_o) T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, - n_idx * block_N:(n_idx + 1) * block_N]) + Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_state_fwd( - batch, - seq_len, - chunk_size, - groups, - heads, - dim, - dstate, - block_M=64, - block_N=128, - block_K=64, - num_stages=4, - threads=128) + batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index 5944541..ccb11fe 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel( H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel( @T.prim_func def chunk_retention_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - log_decay = T.alloc_var('float32') - log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay + log_decay = T.alloc_var("float32") + log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay q = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype) @@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel( for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): - s_shared[row, - col] = T.if_then_else(row >= col, s[row, col] * T.exp2( - (row - col) * log_decay), 0) + s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0) T.copy(h, h_shared) T.gemm(q, h_shared, o, clear_accum=True) @@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel( v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) for row, col in T.Parallel(BK, BV): h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] - T.copy( - o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV]) + T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV]) T.gemm(k, v, h, transpose_A=True) return chunk_retention_fwd @@ -89,24 +84,24 @@ def postprocess(o): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=4096, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=4096, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D total_flops = 2.0 * B * S * S * H * D # causal - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) kernel = chunk_retention_fwd_kernel(B, S, H, D, D) t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) - print(f'Tilelang latency: {t:.3f} ms') - print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') + print(f"Tilelang latency: {t:.3f} ms") + print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index 48df3e0..6600bb5 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -15,12 +15,11 @@ from tilelang.profiler import do_bench @tilelang.jit(out_idx=[3]) def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): - block_M = 64 block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 + scale = (1.0 / dim) ** 0.5 * 1.44269504 shape = [batch, heads, seq_len, dim] seq_blocks = (seq_len + block_M - 1) // block_M @@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz offset_shape = count_shape + [slash_size] index_shape = count_shape + [vertical_size] - vertical_size_round, slash_size_round = tilelang.next_power_of_2( - vertical_size), tilelang.next_power_of_2(slash_size) + vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) dtype = "float16" accum_dtype = "float" int_dtype = "int32" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def Prefetch( K: T.Tensor(shape, dtype), @@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ): with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - K_shared[i, j] = T.if_then_else(k + i < column_count, - K[bz, by, column_index[k + i], j], 0) + K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0) with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - V_shared[i, j] = T.if_then_else(k + i < column_count, - V[bz, by, column_index[k + i], j], 0) + V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0) T.ptx_commit_group() @T.macro def Compute( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - k: T.int32, - column_count: T.int32, - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - count: T.int32, + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + k: T.int32, + column_count: T.int32, + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + count: T.int32, ): T.ptx_wait_group(count) for i, j in T.Parallel(block_M, block_N): @@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz @T.prim_func def vs_sparse_flashattn_ws( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - BlockCount: T.Tensor(count_shape, int_dtype), - BlockOffset: T.Tensor(offset_shape, int_dtype), - ColumnCount: T.Tensor(count_shape, int_dtype), - ColumnIndex: T.Tensor(index_shape, int_dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + BlockCount: T.Tensor(count_shape, int_dtype), + BlockOffset: T.Tensor(offset_shape, int_dtype), + ColumnCount: T.Tensor(count_shape, int_dtype), + ColumnIndex: T.Tensor(index_shape, int_dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): - bx = T.ceildiv(seq_len, block_M) - 1 - bc Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([2, block_N, dim], dtype) @@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz T.create_list_of_mbarrier([128] * 9) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) block_count[0] = BlockCount[bz, by, bx] column_count[0] = ColumnCount[bz, by, bx] @@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz if tid >= 128: T.annotate_producer_reg_dealloc() - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.mbarrier_arrive(mbarrier=8) for bi in T.serial(block_count[0]): k = block_offset[bi] T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :]) + T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2) T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :]) + T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2 + 2) else: T.annotate_consumer_reg_alloc() @@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz for bi in T.serial(block_count[0]): k = block_offset[bi] for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) - T.gemm( - Q_shared, - K_shared[bi % 2, :, :], - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 4) T.copy(scores_max, scores_max_prev) @@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): acc_o[i, j] = acc_o[i, j] * scores_scale[i] T.copy(acc_s, acc_s_cast) - T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1))) - T.gemm( - acc_s_cast, - V_shared[bi % 2, :, :], - acc_o, - policy=T.GemmWarpPolicy.FullRow) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1)) + T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 6) @@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] if column_count[0] != 0: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, - by) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by) for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): k = bi * block_N if bi % 2 == 0: - Prefetch(K, V, K_shared_2, V_shared_2, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_1, V_shared_1, - scores_scale, scores_sum, logsum, 1) + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 1, + ) else: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_2, V_shared_2, - scores_scale, scores_sum, logsum, 1) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 1, + ) if T.ceildiv(column_count[0], block_N) % 2 == 0: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, - scores_sum, logsum, 0) + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 0, + ) else: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, - scores_sum, logsum, 0) + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 0, + ) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return vs_sparse_flashattn_ws @@ -470,11 +502,8 @@ def vertical_slash_sparse_attention( import os current_dir = os.path.dirname(os.path.abspath(__file__)) - sources = [ - os.path.join(current_dir, 'ops', 'kernels.cpp'), - os.path.join(current_dir, 'ops', 'vertical_slash_index.cu') - ] - ops = load(name='convert', sources=sources, verbose=False) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes batch_size, num_heads, context_size, head_dim = query.shape pad = (block_size_M - context_size) & (block_size_M - 1) @@ -485,15 +514,13 @@ def vertical_slash_sparse_attention( value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) - v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=True)[0] + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) sm_scale = head_dim**-0.5 @@ -506,8 +533,7 @@ def vertical_slash_sparse_attention( block_size_N, ) - tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, - v_idx.shape[2], s_idx.shape[2]) + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2]) def run(is_triton: bool = True): if is_triton: @@ -525,8 +551,7 @@ def vertical_slash_sparse_attention( block_size_N, ) else: - out = tl_kernel(query, key, value, block_count, block_offset, column_count, - column_index) + out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) return out[..., :context_size, :head_dim] return run @@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor): b, h, n, m = mat.shape zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right - mat_strided = mat_padded.as_strided( - (1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides + mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns return sum_diags[:, :, 1:] @@ -559,24 +583,23 @@ def main(argv=None): vertical_size, slash_size = args.vertical_size, args.slash_size torch.manual_seed(0) - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) q_len = SEQ_LEN vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) last_q = 64 - qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) arange = torch.arange(last_q, device="cuda") - qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], - qk[:, :, :, -last_q:], -torch.inf) + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) vertical = qk.sum(-2, keepdim=True) vertical[..., :30] = torch.inf vertical_topk = torch.topk(vertical, vertical_size, -1).indices - slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1] + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] slash[..., -30:] = torch.inf slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index 40d367c..a7a06b9 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] @@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index a05f9b0..124a212 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] @@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py index 432482d..32f1c00 100644 --- a/examples/online_softmax/online_softmax.py +++ b/examples/online_softmax/online_softmax.py @@ -20,8 +20,8 @@ def softmax_kernel( @T.prim_func def main( - X: T.Tensor([M, N], dtype), - Y: T.Tensor([M, N], dtype), + X: T.Tensor([M, N], dtype), + Y: T.Tensor([M, N], dtype), ): with T.Kernel(M, threads=128) as (i_m): x = T.alloc_fragment([BN], dtype) @@ -33,7 +33,7 @@ def softmax_kernel( T.fill(lse, -T.infinity(accum_dtype)) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) T.reduce_max(x, max_x, dim=0, clear=True) @@ -45,12 +45,12 @@ def softmax_kernel( lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) for j in T.Parallel(BN): y[j] = T.exp2(x[j] * scale - lse[0]) - T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN]) + T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN]) return main @@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100) t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) print(f"torch latency: {t1:.3f} ms") print(f"TileLang latency: {t2:.3f} ms") -print(f"Speedup: {t1/t2:.3f}x") +print(f"Speedup: {t1 / t2:.3f}x") diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py index 2c3b282..a7e8f89 100644 --- a/examples/plot_layout/fragment_mfma_load_a.py +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import ( ) -def make_mfma_load_base_layout(dtype: str = "float16", - matrix: Literal["A", "B"] = "A", - k_dim: int = 16, - transposed: bool = False) -> T.Fragment: +def make_mfma_load_base_layout( + dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False +) -> T.Fragment: """ Create a layout function for storing MFMA results into a fragment buffer. This layout is used in conjunction with `inverse_mfma_store_layout` to @@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16", # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix == "A": - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) micro_size_s, micro_size_r = micro_size_x, micro_size_k elif matrix == "B": - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) micro_size_s, micro_size_r = micro_size_k, micro_size_y else: raise ValueError(f"Unsupported matrix {matrix}") @@ -120,14 +117,11 @@ print(base_layout) plot_layout(base_layout, name="base_layout") # warp layout 32x32 -warp_layout = base_layout.repeat([warp_rows, warp_cols], - repeat_on_thread=False, - lower_dim_first=False) +warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False) print(warp_layout) plot_layout(warp_layout, name="warp_layout") # block layout 64x32 -block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, - lower_dim_first=True).replicate(block_cols) +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols) print(block_layout) plot_layout(block_layout, name="block_layout") diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index 9888994..17d1c6d 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -5,9 +5,7 @@ from tvm.tir import IndexMap from tilelang.intrinsics.utils import get_mma_micro_size -def make_mma_load_base_layout(dtype: str = "float16", - matrix: Literal["A", "B"] = "A", - transposed: bool = False) -> T.Fragment: +def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16", shared_16x16_to_mma_32x8_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_b, ) + assert matrix in ["A", "B"], "matrix should be either A or B" dtype_bits = DataType(dtype).bits # s represents spatial axis @@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16", # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix == "A": - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) micro_size_s, micro_size_r = micro_size_x, micro_size_k elif matrix == "B": - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) micro_size_s, micro_size_r = micro_size_k, micro_size_y else: raise ValueError(f"Unsupported matrix {matrix}") diff --git a/examples/quickstart.py b/examples/quickstart.py index 39ad348..4b765ce 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -7,12 +7,11 @@ import tilelang.language as T # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index 219d3ee..f5f7fe7 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 0 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c block_mask_dtype = "int8" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: past_len = seq_kv - seq_q for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else( - bx * block_M + i + past_len >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -165,44 +155,40 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.float16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) # Run tilelang kernel - kernel = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) # Verify accuracy - assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ - "TileLang output doesn't match reference" + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.float16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) # Force the first column to be high so that the first block is always selected. x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - kernel = blocksparse_flashattn( - BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) print(kernel.get_kernel_source()) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) diff --git a/examples/seer_attention/block_sparse_attn_triton.py b/examples/seer_attention/block_sparse_attn_triton.py index ed33cc1..b4cc3cd 100644 --- a/examples/seer_attention/block_sparse_attn_triton.py +++ b/examples/seer_attention/block_sparse_attn_triton.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -54,7 +51,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -69,7 +65,7 @@ def _fwd_kernel_inner( qk *= sm_scale # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -149,7 +145,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -185,24 +181,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -247,7 +231,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,9 +254,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -281,9 +264,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -295,22 +276,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 8707c94..6c37dc0 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -28,24 +28,22 @@ def matmul_sp( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // 8), 'uint8'), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') + E_shared = T.alloc_shared((block_M, block_K // 8), "uint8") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="9.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), + } + ) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // 8], E_shared) @@ -57,7 +55,7 @@ def matmul_sp( return main -def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'): +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"): if shape[-1] % 4 != 0: raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") @@ -102,9 +100,9 @@ def run_gemm_sp( num_threads, ) - A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda') + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) - B = torch.randn((K, N), device='cuda', dtype=torch.float16) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) C_sp = kernel(A_sparse, E, B).half() C = torch.matmul(A, B) diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index 0ca19fb..c0cf09b 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -26,9 +26,9 @@ def tl_topk( @T.prim_func def topk_kernel( - logits: T.Tensor([M, N], dtype), - topk_gates: T.Tensor([M, topk], dtype), - topk_indices: T.Tensor([M, topk], "int32"), + logits: T.Tensor([M, N], dtype), + topk_gates: T.Tensor([M, topk], dtype), + topk_indices: T.Tensor([M, topk], "int32"), ): with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx: logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) @@ -43,15 +43,12 @@ def tl_topk( T.reduce_max(logits_frag, max_val, dim=1, clear=True) for i, j in T.Parallel(blk_m, N): - expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, - expand_max_idx[i, j]) + expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j]) T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True) for i, j in T.Parallel(blk_m, N): - - logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, - logits_frag[i, j]) + logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j]) for i in T.Parallel(blk_m): topk_gates[bx * blk_m + i, k] = max_val[i] @@ -61,7 +58,6 @@ def tl_topk( def ref_program(logits, top_k): - top_k_gates, top_k_indices = logits.topk(top_k, dim=1) return top_k_gates, top_k_indices.to(torch.int32) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 3677d47..dbb39f7 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -7,15 +7,15 @@ import tilelang.language as T out_idx=[-1], pass_configs={ tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, - tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg" - }) + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", + }, +) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -49,12 +49,12 @@ def main(): print("All check passed.") # print the layout visualization result and save figures to ./tmp. - ''' + """ C_local inferenced layout: Shape: [32, 32] -> [8] Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] - ''' + """ if __name__ == "__main__": diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py index 4a8f41e..4f4417e 100644 --- a/examples/warp_specialize/example_warp_specialize_flashmla.py +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -9,7 +9,7 @@ import argparse @tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): # smem_sQ @@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), - O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), - }) + T.annotate_layout( + { + O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), + O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), + } + ) # barriers_Q q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) @@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.barrier_arrive(q_shared_ready_barrier) T.barrier_wait(q_shared_ready_barrier, 0) @@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.fill(acc_o_l, 0) T.fill(logsum_0, 0) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) - T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1) + T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) for k in T.serial(loop_range): - T.barrier_wait(kv_shared_0_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_0_l, - acc_s_0, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_0_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) @@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, block_N): acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale) for i in T.Parallel(block_H): - scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - - scores_max[i] * scale) + scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale) T.reduce_sum(acc_s_0, scores_sum_0, dim=1) @@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_wait(scale_1_ready_barrier, k % 2) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, - cur_kv_head, :h_dim], KV_shared_0_l) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l) T.barrier_arrive(kv_shared_0_l_is_ready) # Step 11. @@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.gemm(SP1_shared, KV_shared_1_l, acc_o_l) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, - cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy( - K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :], - K_pe_shared_1) + T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) T.copy(logsum_0, logsum) @@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, h_dim): acc_o_l[i, j] /= logsum[i] T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, - hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim]) else: T.copy(Q_pe_shared, Q_pe_local_1) @@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_arrive(kv_shared_0_pe_is_ready) for k in T.serial(loop_range): - # Step 2. T.barrier_wait(kv_shared_1_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_1_l, - acc_s_1, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_1_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) @@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max_1, scores_max) for i in T.Parallel(block_H): - scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - - scores_max[i] * scale) + scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale) # Step 8. for i, j in T.Parallel(block_H, block_N): @@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i]) for i in T.Parallel(block_H): - logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[ - i] + scores_sum_1[i] + logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i] T.barrier_arrive(scale_1_ready_barrier) @@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_arrive(s_shared_ready_barrier) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, - h_dim:], KV_shared_1_r) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) T.barrier_wait(p0_1_1_ready_barrier, k % 2) @@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.gemm(SP0_shared, KV_shared_0_r, acc_o_r) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, - h_dim:], KV_shared_0_r) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r) T.barrier_arrive(kv_shared_0_r_is_ready) - T.copy( - K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :], - K_pe_shared_0) + T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0) T.barrier_arrive(kv_shared_0_pe_is_ready) T.barrier_wait(lse_0_ready_barrier, 0) @@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, h_dim): acc_o_r[i, j] /= logsum[i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - h_dim:]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:]) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index b738a4b..5d438b5 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -8,7 +8,6 @@ tilelang.disable_cache() # @tilelang.jit @tilelang.jit(out_idx=[2]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - num_stages = 2 mbarrier_list = [128, 128] * num_stages @@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo for ko in range(T.ceildiv(K, block_K)): with T.ws(1): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages + num_stages, - parity=((ko // num_stages) % num_stages) ^ 1) - T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K], - A_shared[ko % num_stages, :, :]) - T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N], - B_shared[ko % num_stages, :, :]) + T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1) + T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :]) + T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :]) T.mbarrier_arrive(mbarrier=ko % num_stages) with T.ws(0): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) - T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], - C_local) + T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) + T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local) T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages) with T.ws(0): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py index 9ba9f68..03ddf81 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -5,20 +5,12 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_0_gemm_1(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index faaf48c..63aed2b 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -5,20 +5,12 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py index c912745..f24d76a 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -5,26 +5,20 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - + }, +) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): warp_group_num = 2 threads = 128 * warp_group_num @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py index 3b1d867..f3f8a66 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -6,7 +6,6 @@ import tilelang.language as T # @tilelang.jit @tilelang.jit(out_idx=[2]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( A: T.Tensor[(M, K), dtype], diff --git a/format.sh b/format.sh index e820b58..3cc4390 100755 --- a/format.sh +++ b/format.sh @@ -9,7 +9,7 @@ # bash format.sh --all # # -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase. # You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py index 33a5812..e7a8225 100644 --- a/maint/gemm_v2/correctness_evaluation.py +++ b/maint/gemm_v2/correctness_evaluation.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -66,7 +66,8 @@ def _compile_and_check( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }) + }, + ) print(kernel.get_kernel_source()) @@ -151,9 +152,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -238,9 +239,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -326,9 +327,9 @@ def matmul_rr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256] N_VALUES = [16, 32, 64, 128, 256, 512] K_VALUES = [16, 32, 64, 128] K_VALUES_8Bit = [32, 64, 128] -FALSE_TRUE_CASES = ([ - pytest.param( - k, - "float16", - "float16", - "float16", - id=f"K{k}-float16-float16-float16", - ) for k in K_VALUES -] + [pytest.param( - k, - "int8", - "int32", - "int32", - id="K32-int8-int32-int32", -) for k in K_VALUES_8Bit] + [ - pytest.param( - k, - "float8_e5m2", - "float32", - "float32", - id="K32-float8_e5m2-float32-float32", - ) for k in K_VALUES_8Bit -] + [ - pytest.param( - k, - "float8_e4m3", - "float32", - "float32", - id="K32-float8_e4m3-float32-float32", - ) for k in K_VALUES_8Bit -]) +FALSE_TRUE_CASES = ( + [ + pytest.param( + k, + "float16", + "float16", + "float16", + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES + ] + + [ + pytest.param( + k, + "int8", + "int32", + "int32", + id="K32-int8-int32-int32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + "float8_e5m2", + "float32", + "float32", + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + "float8_e4m3", + "float32", + "float32", + id="K32-float8_e4m3-float32-float32", + ) + for k in K_VALUES_8Bit + ] +) def _ensure_torch_dtypes(*dtype_names): diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py index 128f4ab..3b4503d 100644 --- a/maint/gemm_v2/correctness_evaluation_sm70.py +++ b/maint/gemm_v2/correctness_evaluation_sm70.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -67,7 +67,8 @@ def _compile_and_check( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }) + }, + ) print(kernel.get_kernel_source()) @@ -150,9 +151,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -213,14 +214,15 @@ def run_gemm_rs( M_VALUES = [64, 128] N_VALUES = [32, 64, 128] K_VALUES = [16, 32, 64] -FALSE_TRUE_CASES = ([ +FALSE_TRUE_CASES = [ pytest.param( k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16", - ) for k in K_VALUES + ) + for k in K_VALUES ] + [ pytest.param( k, @@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([ "float16", "float32", id=f"K{k}-float16-float16-float32", - ) for k in K_VALUES -]) + ) + for k in K_VALUES +] def _ensure_torch_dtypes(*dtype_names): diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py index 1831ac8..4ce8691 100644 --- a/maint/gemm_v2/correctness_evaluation_tcgen05.py +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -42,15 +42,7 @@ def matmul( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -74,7 +66,8 @@ def _compile_and_check( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) print(kernel.get_kernel_source()) @@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256] N_VALUES = [64, 128, 256, 512] K_VALUES = [16, 32, 64, 128] K_VALUES_8Bit = [32, 64, 128] -FALSE_TRUE_CASES = ([ +FALSE_TRUE_CASES = [ pytest.param( k, "float16", "float32", "float32", id=f"K{k}-float16-float-float", - ) for k in K_VALUES + ) + for k in K_VALUES ] + [ pytest.param( k, @@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([ "float32", "float32", id="K32-float8_e5m2-float32-float32", - ) for k in K_VALUES_8Bit -]) + ) + for k in K_VALUES_8Bit +] TRANS_CASES = [ pytest.param(False, True, id="nt"), diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py index 07a5020..4dcb7cf 100644 --- a/maint/gemm_v2/latency.py +++ b/maint/gemm_v2/latency.py @@ -14,12 +14,11 @@ use_v2 = args.use_v2 # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py index 13392de..a66167d 100644 --- a/maint/gemm_v2/latency_gemm.py +++ b/maint/gemm_v2/latency_gemm.py @@ -14,12 +14,11 @@ use_v2 = args.use_v2 # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py index 4126bb9..3fd5600 100644 --- a/maint/gemm_v2/latency_mha_fwd_bhsd.py +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -8,13 +8,13 @@ import argparse from functools import partial parser = argparse.ArgumentParser() -parser.add_argument('--batch', type=int, default=128, help='batch size') -parser.add_argument('--heads', type=int, default=16, help='heads') -parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length') -parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length') -parser.add_argument('--dim', type=int, default=256, help='dim') -parser.add_argument('--is_causal', action='store_true', help='causal') -parser.add_argument('--tune', action='store_true', help='tune configs') +parser.add_argument("--batch", type=int, default=128, help="batch size") +parser.add_argument("--heads", type=int, default=16, help="heads") +parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length") +parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length") +parser.add_argument("--dim", type=int, default=256, help="dim") +parser.add_argument("--is_causal", action="store_true", help="causal") +parser.add_argument("--tune", action="store_true", help="tune configs") parser.add_argument("--use_v2", action="store_true") args = parser.parse_args() @@ -29,20 +29,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -62,7 +55,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -85,7 +78,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if use_v2: T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @@ -94,13 +87,13 @@ def flashattn(batch, @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -125,18 +118,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -152,43 +145,42 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -206,18 +198,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128) print(kernel.get_kernel_source()) ref_program_processed = partial(ref_program, is_causal=is_causal) diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py index 8ba3664..9528652 100644 --- a/maint/host_checks/01_num_args_mismatch.py +++ b/maint/host_checks/01_num_args_mismatch.py @@ -3,6 +3,7 @@ Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. Calling with the wrong number of inputs raises a ValueError before host entry. """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py index fd35854..188a4f8 100644 --- a/maint/host_checks/02_pointer_type_error.py +++ b/maint/host_checks/02_pointer_type_error.py @@ -3,6 +3,7 @@ We pass an integer for A; wrapper forwards it to the host where a pointer is expected. Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py index 994ce23..76637e8 100644 --- a/maint/host_checks/03_ndim_mismatch.py +++ b/maint/host_checks/03_ndim_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: ndim (rank) mismatch for A. -""" +"""Reproduce: ndim (rank) mismatch for A.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py index 6e6a050..f3554c1 100644 --- a/maint/host_checks/04_dtype_mismatch.py +++ b/maint/host_checks/04_dtype_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: dtype mismatch for A (float32 vs expected float16). -""" +"""Reproduce: dtype mismatch for A (float32 vs expected float16).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py index 8b41ae3..a482481 100644 --- a/maint/host_checks/05_shape_mismatch.py +++ b/maint/host_checks/05_shape_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: shape constant/symbol mismatch on A. -""" +"""Reproduce: shape constant/symbol mismatch on A.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py index 477d200..7e523cd 100644 --- a/maint/host_checks/06_strides_mismatch.py +++ b/maint/host_checks/06_strides_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: strides check failure (non-contiguous A via transpose). -""" +"""Reproduce: strides check failure (non-contiguous A via transpose).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py index 67cb771..af8e5ef 100644 --- a/maint/host_checks/07_device_type_mismatch.py +++ b/maint/host_checks/07_device_type_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel. -""" +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py index 6491096..280aca1 100644 --- a/maint/host_checks/08_device_id_mismatch.py +++ b/maint/host_checks/08_device_id_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: device_id mismatch (requires >=2 CUDA devices). -""" +"""Reproduce: device_id mismatch (requires >=2 CUDA devices).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py index 00bac67..09f5de1 100644 --- a/maint/host_checks/09_null_data_pointer.py +++ b/maint/host_checks/09_null_data_pointer.py @@ -7,6 +7,7 @@ or a host-side non-NULL pointer check. Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script demonstrates passing None, which still reproduces the intended class of failure. """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py index f1fcba2..4f2c90b 100644 --- a/maint/host_checks/10_scalar_type_mismatch.py +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: scalar parameter type mismatch (int/bool). -""" +"""Reproduce: scalar parameter type mismatch (int/bool).""" + from common import build_scalar_check_kernel diff --git a/maint/host_checks/common.py b/maint/host_checks/common.py index cdafc8b..649527d 100644 --- a/maint/host_checks/common.py +++ b/maint/host_checks/common.py @@ -3,20 +3,12 @@ import tilelang.language as T import torch -def make_matmul_prim(M, - N, - K, - block_M=128, - block_N=128, - block_K=32, - dtype="float16", - accum_dtype="float"): - +def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"): def build_scalar_check_kernel(target="cuda"): - @T.prim_func def scalar_check(x: T.int32, flag: T.bool()): T.evaluate(0) diff --git a/maint/precision/compare_ops.py b/maint/precision/compare_ops.py index 7d0d67d..985c3bd 100644 --- a/maint/precision/compare_ops.py +++ b/maint/precision/compare_ops.py @@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = { 6: "sqrt", 7: "tanh", 8: "rsqrt", - 9: "inv_sqrt" + 9: "inv_sqrt", } # Block sizes for kernels @@ -49,8 +49,7 @@ TILELANG_THREADS = 128 def parse_arguments() -> argparse.Namespace: """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Precision comparison tool for various CUDA implementations") + parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations") parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test") parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values") parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values") @@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module: return load( name="cuda_ops", sources=["cuda_ops.cu"], - extra_cuda_cflags=[] # No fast_math flags + extra_cuda_cflags=[], # No fast_math flags ) @@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S @triton.jit -def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, - BLOCK_SIZE: tl.constexpr): +def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr): """LibDevice Triton kernel for unary operations.""" pid = tl.program_id(0) block_start = pid * BLOCK_SIZE @@ -188,13 +186,10 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = @T.prim_func def tilelang_unary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): - with T.Kernel( - T.ceildiv(N, TILELANG_BLOCK_N), - T.ceildiv(M, TILELANG_BLOCK_M), - threads=TILELANG_THREADS) as (bx, by): + with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): row = by * TILELANG_BLOCK_M + i col = bx * TILELANG_BLOCK_N + j @@ -229,14 +224,11 @@ def make_tilelang_binary_kernel(M: int, N: int): @T.prim_func def tilelang_binary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), - C: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), + C: T.Tensor((M, N), "float32"), ): - with T.Kernel( - T.ceildiv(N, TILELANG_BLOCK_N), - T.ceildiv(M, TILELANG_BLOCK_M), - threads=TILELANG_THREADS) as (bx, by): + with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): row = by * TILELANG_BLOCK_M + i col = bx * TILELANG_BLOCK_N + j @@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int): return tilelang_binary_kernel -def tilelang_op(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None, - use_fastmath: bool = False) -> torch.Tensor: +def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor: """TileLang operation interface.""" assert x.is_cuda @@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath, - }) + }, + ) out = kernel(x, y) else: # Unary operation kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath) @@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath, - }) + }, + ) out = kernel(x) # Restore original shape @@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> """Standard Triton operation interface.""" assert x.is_cuda out = torch.empty_like(x) - grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],) + grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) if op_id == 0: # Division - binary operation assert y is not None, "Division operation requires second operand" @@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> return out -def triton_libdevice_op(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None) -> torch.Tensor: +def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: """LibDevice Triton operation interface.""" assert x.is_cuda out = torch.empty_like(x) - grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],) + grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) if op_id == 0: # Division - binary operation assert y is not None, "Division operation requires second operand" @@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor, return out -def get_pytorch_reference(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None) -> torch.Tensor: +def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: """Get PyTorch reference implementation for the given operation.""" if op_id == 0: assert y is not None, "Division requires second operand" @@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T abs_err = (output_double - reference_double).abs() rel_err = abs_err / (reference_double.abs().clamp_min(1e-30)) - print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, " - f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}") + print( + f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, " + f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}" + ) # Precision comparison function @@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No results[name] = None # Print comparison header - print( - f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}" - ) + print(f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}") print("-" * 90) # Compare all implementations against double precision reference @@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No summarize_error(tag, output, ref_double) -def generate_test_data(op_id: int, n: int, device: torch.device, low: float, - high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Generate appropriate test data for each operation.""" if op_id == 0: # Division x = torch.empty(n, device=device).uniform_(low, high) @@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float, def main() -> None: """Main execution function.""" - print( - "Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang" - ) + print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang") print("=" * 90) for op_id in range(len(OP_NAMES)): diff --git a/maint/scripts/ci_performance.py b/maint/scripts/ci_performance.py index 998e7b6..8a353c0 100644 --- a/maint/scripts/ci_performance.py +++ b/maint/scripts/ci_performance.py @@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1" def parse_output(output): data = {} - for line in output.split('\n'): + for line in output.split("\n"): line = line.strip() - if line.startswith('Latency:'): - match = re.search(r'Latency: ([\d.]+)', line) - data['latency'] = match.group(1) if match else 'N/A' - elif line.startswith('TFlops:'): - match = re.search(r'TFlops: ([\d.]+)', line) - data['best_tflops'] = match.group(1) if match else 'N/A' - elif line.startswith('Config:'): - data['config'] = line.split('Config: ')[-1] - elif line.startswith('Reference TFlops:'): - match = re.search(r'Reference TFlops: ([\d.]+)', line) - data['ref_tflops'] = match.group(1) if match else 'N/A' + if line.startswith("Latency:"): + match = re.search(r"Latency: ([\d.]+)", line) + data["latency"] = match.group(1) if match else "N/A" + elif line.startswith("TFlops:"): + match = re.search(r"TFlops: ([\d.]+)", line) + data["best_tflops"] = match.group(1) if match else "N/A" + elif line.startswith("Config:"): + data["config"] = line.split("Config: ")[-1] + elif line.startswith("Reference TFlops:"): + match = re.search(r"Reference TFlops: ([\d.]+)", line) + data["ref_tflops"] = match.group(1) if match else "N/A" return data -output_v1 = subprocess.run(['./tl/bin/python', './maint/scripts/performance.py'], - capture_output=True, - text=True, - env=env).stdout +output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout data_v1 = parse_output(output_v1) -output_v2 = subprocess.run(['./tll/bin/python', './maint/scripts/performance.py'], - capture_output=True, - text=True, - env=env).stdout +output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout data_v2 = parse_output(output_v2) -table = [[ - "original", data_v1['latency'], data_v1['best_tflops'], data_v1['ref_tflops'], data_v1['config'] -], [ - "current", data_v2['latency'], data_v2['best_tflops'], data_v2['ref_tflops'], data_v2['config'] -]] +table = [ + ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]], + ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]], +] headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"] diff --git a/maint/scripts/performance.py b/maint/scripts/performance.py index 24c4a21..849bcf3 100644 --- a/maint/scripts/performance.py +++ b/maint/scripts/performance.py @@ -8,19 +8,20 @@ def ref_program(A, B): def get_configs(): - configs = [{ - "block_M": 128, - "block_N": 128, - "block_K": 64, - "num_stages": 2, - "thread_num": 256, - "enable_rasteration": True, # keep param name for backward-compat - }] + configs = [ + { + "block_M": 128, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 256, + "enable_rasteration": True, # keep param name for backward-compat + } + ] return configs def run(M, N, K): - def kernel( block_M=None, block_N=None, @@ -34,12 +35,11 @@ def run(M, N, K): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -60,12 +60,16 @@ def run(M, N, K): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs()).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs()) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( - ref_prog=ref_program,) + ) + .set_profile_args( + ref_prog=ref_program, + ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/pyproject.toml b/pyproject.toml index 2246713..992eba5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,10 +122,7 @@ tilelang = "tilelang" "tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include" "tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library" -[tool.yapf] -based_on_style = "yapf" -column_limit = 100 -indent_width = 4 + [tool.codespell] ignore-words = "docs/spelling_wordlist.txt" @@ -138,7 +135,7 @@ skip = [ [tool.ruff] target-version = "py39" -line-length = 100 +line-length = 140 output-format = "full" exclude = [ @@ -146,6 +143,14 @@ exclude = [ "examples/deepseek_v32/inference", ] +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = false +docstring-code-line-length = "dynamic" + [tool.ruff.lint.per-file-ignores] # Do not upgrade type hint in testing and examples. # See https://github.com/tile-ai/tilelang/issues/1079 for more information. diff --git a/requirements-lint.txt b/requirements-lint.txt index e64eee1..54f0363 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -4,4 +4,3 @@ clang-format==21.1.2 clang-tidy==21.1.1 codespell[toml]==2.4.1 ruff==0.14.3 -yapf==0.43.0 diff --git a/testing/conftest.py b/testing/conftest.py index 9f49d40..4010e0d 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index a01bd45..4007beb 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -4,7 +4,8 @@ from tilelang import tvm as tvm import tilelang.language as T from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mfma_macro_generator import ( - MatrixCoreIntrinEmitter,) + MatrixCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func tilelang.testing.set_random_seed(0) @@ -22,7 +23,6 @@ def tl_matmul( b_transposed=True, k_pack=1, ): - micro_size_x = micro_size_y = micro_size_k = 16 if in_dtype in {"float8_e4m3fnuz", "int8"}: @@ -78,12 +78,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -91,10 +90,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -102,7 +103,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=0): - # Load A into shared memory if a_transposed: T.copy(A[ko * block_K, by * block_M], A_shared) @@ -116,7 +116,6 @@ def tl_matmul( T.copy(B[ko * block_K, bx * block_N], B_shared) for ki in T.serial(0, (block_K // (k_pack * micro_size_k))): - # Load A into fragment mfma_emitter.ldmatrix_a( A_local, @@ -160,17 +159,8 @@ def tl_matmul( return main -def assert_tl_matmul_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype="float32", - a_transposed=False, - b_transposed=True, - k_pack=1): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, - k_pack) +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) print(matmul) kernel = tilelang.compile(matmul) src_code = kernel.get_kernel_source() @@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M, if a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.T.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) elif a_transposed and not b_transposed: # Get Reference Result - ref_c = torch.matmul(A.Tto(torch.float32), - B.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) elif not a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) else: # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) @@ -228,16 +215,13 @@ def test_assert_tl_matmul(): assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32") assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32") assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2) - assert_tl_matmul_correctness( - 128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32") - assert_tl_matmul_correctness( - 128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2) assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16") assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32") assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2) assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False) - assert_tl_matmul_correctness( - 128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) if __name__ == "__main__": diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index b215f0d..393a77b 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -23,7 +23,6 @@ def tl_matmul( b_preshuffle=False, b_g2l_load=False, ): - micro_size_x = micro_size_y = micro_size_k = 16 if in_dtype in {"float8_e4m3fnuz", "int8"}: @@ -53,18 +52,21 @@ def tl_matmul( A_shape = (K, M) if a_transposed else (M, K) if b_preshuffle: - B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y, - pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y, - pack_size_k, micro_size_y) + B_shape = ( + (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y) + ) else: B_shape = (N, K) if b_transposed else (K, N) A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) if b_preshuffle: - B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, - pack_size_k) if b_transposed else (block_K // pack_size_k, - block_N // micro_size_y, pack_size_k, - micro_size_y) + B_shared_shape = ( + (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y) + ) else: B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) @@ -94,21 +96,22 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) num_ko = K // block_K num_ki = block_K // (k_pack * micro_size_k) @@ -119,7 +122,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined(num_ko, num_stages=0): - # Load A into shared memory if a_transposed: T.copy(A[ko * block_K, by * block_M], A_shared) @@ -129,20 +131,13 @@ def tl_matmul( # Load B into shared memory if b_g2l_load is False: if b_transposed: - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, - block_K // pack_size_k, micro_size_y, - pack_size_k): - B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, - ko * block_K // pack_size_k + k, jj, kk] + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k): + B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk] else: - for k, j, kk, jj in T.Parallel(block_K // pack_size_k, - block_N // micro_size_y, pack_size_k, - micro_size_y): - B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, - bx * block_N // micro_size_y + j, kk, jj] + for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y): + B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj] for ki in T.serial(0, num_ki): - # Load A S2L mfma_emitter.ldmatrix_a( A_local, @@ -176,10 +171,10 @@ def tl_matmul( def shuffle_weight( - x: torch.Tensor, - layout=(16, 32), - k_pack=1, - is_transpose=False, + x: torch.Tensor, + layout=(16, 32), + k_pack=1, + is_transpose=False, ) -> torch.Tensor: IN, IK = layout BK = IK * k_pack @@ -194,19 +189,20 @@ def shuffle_weight( return x.contiguous() -def assert_tl_matmul_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype="float32", - a_transposed=False, - b_transposed=True, - k_pack=1, - b_preshuffle=False, - b_g2l_load=False): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, - k_pack, b_preshuffle, b_g2l_load) +def assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype="float32", + a_transposed=False, + b_transposed=True, + k_pack=1, + b_preshuffle=False, + b_g2l_load=False, +): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load) print(matmul) kernel = tilelang.compile(matmul) src_code = kernel.get_kernel_source() @@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M, if a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.T.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) elif a_transposed and not b_transposed: # Get Reference Result - ref_c = torch.matmul(A.Tto(torch.float32), - B.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) elif not a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) else: # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) @@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M, @tilelang.testing.requires_rocm def test_assert_tl_matmul(): - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True) - - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, - 256, - 512, - "int8", - "int32", - b_transposed=False, - accum_dtype="int32", - k_pack=2, - b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True) + + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True) assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, - 256, - 512, - "float8_e4m3fnuz", - "float32", - k_pack=2, - b_transposed=False, - b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True) if __name__ == "__main__": diff --git a/testing/python/amd/test_tilelang_test_amd.py b/testing/python/amd/test_tilelang_test_amd.py index 456a3ae..0666fd4 100644 --- a/testing/python/amd/test_tilelang_test_amd.py +++ b/testing/python/amd/test_tilelang_test_amd.py @@ -27,8 +27,7 @@ def matmul( vec_size = 4 * k_pack @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt(): run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32) - run_gemm( - 1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2) + run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2) @tilelang.testing.requires_rocm @@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32(): run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm( - 1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2) + run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2) def matmul_rs( @@ -149,9 +146,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index df88573..85aa518 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -5,14 +5,12 @@ import pytest @tilelang.jit -def simple_invalid_loop(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -28,14 +26,12 @@ def simple_invalid_loop(dtype: str = "bfloat16", @tilelang.jit -def nested_invalid_loop(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -52,14 +48,12 @@ def nested_invalid_loop(dtype: str = "bfloat16", @tilelang.jit -def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -75,14 +69,12 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_not_use_loop_var(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -99,14 +91,12 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_not_frag(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_shared = T.alloc_shared([128], accum_dtype) @@ -122,14 +112,12 @@ def valid_loop_not_frag(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_serial(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_shared = T.alloc_shared([128], accum_dtype) diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index d3c2ec2..e282c8e 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -30,11 +30,10 @@ Rule: @tilelang.jit(out_idx=[1]) def nested_continuous_parallels(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -46,29 +45,26 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block1 // block2): for j in T.Parallel(block1): for k in T.Parallel(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @tilelang.jit(out_idx=[1]) def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -103,8 +99,9 @@ is OK. """ -def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats): +def matmul_nested_pipelines( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats +): A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) @@ -114,9 +111,9 @@ def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -180,7 +177,8 @@ def run_gemm_nested_pipelines( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -193,8 +191,8 @@ def run_gemm_nested_pipelines( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -218,11 +216,10 @@ is OK. @tilelang.jit(out_idx=[1]) def nested_continuous_serials(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -234,11 +231,10 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_noncontinuous_serials(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -277,11 +273,10 @@ Rule: @tilelang.jit(out_idx=[1]) def nested_continuous_sp(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -293,11 +288,10 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_continuous_ps(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -309,36 +303,32 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block1 // block2): for j in T.serial(block1): for k in T.Parallel(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @tilelang.jit(out_idx=[1]) def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block1 // block2): for j in T.Parallel(block1): for k in T.serial(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @@ -399,9 +389,9 @@ def matmul_nested_pipa( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -444,9 +434,9 @@ def matmul_nested_papipa( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -505,7 +495,8 @@ def run_gemm_mixed_pp( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -514,8 +505,8 @@ def run_gemm_mixed_pp( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -543,7 +534,8 @@ def run_gemm_mixed_pp( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) def test_mixed_pp(): @@ -576,9 +568,9 @@ def matmul_with_parallel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) @tilelang.jit(out_idx=[1]) def tir_op_with_parallel(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def customize_op_with_parallel(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): diff --git a/testing/python/autotune/test_tilelang_autotune.py b/testing/python/autotune/test_tilelang_autotune.py index 85e2e48..3e6a05a 100644 --- a/testing/python/autotune/test_tilelang_autotune.py +++ b/testing/python/autotune/test_tilelang_autotune.py @@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False): from tilelang.carver.template import MatmulTemplate from tilelang.carver.arch import CUDA from tilelang.carver.roller.rasterization import NoRasterization + arch = CUDA("cuda") topk = 20 @@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False): for config in configs: print(config) else: - block_M = [64] block_N = [64] block_K = [32] @@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller): """ # Bind x-dimension to block index in N, # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( - ref_prog=ref_program,) + ) + .set_profile_args( + ref_prog=ref_program, + ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/testing/python/autotune/test_tilelang_autotune_with_inputs.py b/testing/python/autotune/test_tilelang_autotune_with_inputs.py index 39efce6..8f9a609 100644 --- a/testing/python/autotune/test_tilelang_autotune_with_inputs.py +++ b/testing/python/autotune/test_tilelang_autotune_with_inputs.py @@ -30,38 +30,23 @@ def ref_program(A, B): def get_configs(): - iter_params = dict( - block_M=[64], - block_N=[64], - block_K=[32], - num_stages=[0, 1], - thread_num=[128], - enable_rasterization=[False]) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M=128, - block_N=128, - block_K=32, - num_stages=0, - thread_num=128, - enable_rasterization=False): + iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False]) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False): dtype = "float16" accum_dtype = "float" @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -76,7 +61,6 @@ def matmul(M, # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/testing/python/cache/test_tilelang_cache_matmul.py b/testing/python/cache/test_tilelang_cache_matmul.py index 6e966a8..f38ed48 100644 --- a/testing/python/cache/test_tilelang_cache_matmul.py +++ b/testing/python/cache/test_tilelang_cache_matmul.py @@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -63,6 +63,7 @@ def run_cache_matmul(): Reference PyTorch matrix multiplication for comparison. """ import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.half) # Assuming dtype="float16" in matmul return C diff --git a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py index 46b17bf..67d20b8 100644 --- a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py +++ b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py @@ -29,9 +29,7 @@ class _cudaDeviceAttrNames: def test_driver_get_device_properties(): prop = get_cuda_device_properties() assert prop is not None, "Failed to get CUDA device properties" - assert isinstance( - prop, - torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties") + assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties" def test_device_get_device_name(): @@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block(): def test_device_get_persisting_l2_cache_size(): tl_cache_size = get_persisting_l2_cache_max_size() - driver_cache_size = get_device_attribute( - _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) + driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" @@ -61,17 +58,14 @@ def test_device_get_num_sms(): def test_device_get_registers_per_block(): tl_regs_per_block = get_registers_per_block() - driver_regs_per_block = get_device_attribute( - _cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) + driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" def test_device_get_max_dynamic_shared_size_bytes(): tl_dynamic_smem = get_max_dynamic_shared_size_bytes() - driver_dynamic_smem = get_device_attribute( - _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) - assert tl_dynamic_smem == driver_dynamic_smem, ( - "Max dynamic shared size bytes values do not match") + driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) + assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match" if __name__ == "__main__": diff --git a/testing/python/carver/test_tilelang_carver_generate_hints.py b/testing/python/carver/test_tilelang_carver_generate_hints.py index 43cdb27..313dc85 100644 --- a/testing/python/carver/test_tilelang_carver_generate_hints.py +++ b/testing/python/carver/test_tilelang_carver_generate_hints.py @@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name='A', dtype='float16') - B = te.placeholder((N, K), name='B', dtype='float16') + A = te.placeholder((M, K), name="A", dtype="float16") + B = te.placeholder((N, K), name="B", dtype="float16") # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]), - name='C') + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") return A, B, C @@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target) print(tags) - policy = carver.TensorCorePolicy.from_prim_func( - func=tensorized_func, arch=arch, tags=tags, name="matmul_0") + policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0") hints = policy.emit_config(topk=topk) @@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name='A', dtype='float16') - B = te.placeholder((N, K), name='B', dtype='float16') + A = te.placeholder((M, K), name="A", dtype="float16") + B = te.placeholder((N, K), name="B", dtype="float16") # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]), - name='C') + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") return A, B, C diff --git a/testing/python/carver/test_tilelang_carver_recommend_hints.py b/testing/python/carver/test_tilelang_carver_recommend_hints.py index fee4676..4973c24 100644 --- a/testing/python/carver/test_tilelang_carver_recommend_hints.py +++ b/testing/python/carver/test_tilelang_carver_recommend_hints.py @@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch from typing import List -def run_general_reduction_recommend_hints(structure: str = "SSR", - shape: List[int] = None, - dtype: str = "float16", - topk: int = 20): +def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.GeneralReductionTemplate( structure=structure, @@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints(): run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16") -def run_elementwise_recommend_hints(shape: List[int] = None, - dtype: str = "float16", - topk: int = 20): +def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.ElementwiseTemplate( shape=shape, @@ -81,11 +76,9 @@ def test_matmul_recommend_hints(): run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16") -def run_gemv_recommend_hints(N: int = 1024, - K: int = 1024, - in_dtype: str = "float16", - out_dtype: str = "float16", - accum_dtype: str = "float16"): +def run_gemv_recommend_hints( + N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16" +): arch = auto_infer_current_arch() carve_template = carver.GEMVTemplate( N=N, diff --git a/testing/python/components/test_storage_rewrite_detect_inplace.py b/testing/python/components/test_storage_rewrite_detect_inplace.py index 1d60708..bd0a64d 100644 --- a/testing/python/components/test_storage_rewrite_detect_inplace.py +++ b/testing/python/components/test_storage_rewrite_detect_inplace.py @@ -23,7 +23,8 @@ def _compile_kernel_without_inplace(): @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True, - },) + }, +) def _compile_kernel_with_inplace(): num_tokens = T.symbolic("num_tokens") diff --git a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py index 499f334..323f764 100644 --- a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py +++ b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -88,7 +88,8 @@ def run_gemm( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): diff --git a/testing/python/cpu/test_tilelang_cpu_gemm.py b/testing/python/cpu/test_tilelang_cpu_gemm.py index 0129b37..4a878f3 100644 --- a/testing/python/cpu/test_tilelang_cpu_gemm.py +++ b/testing/python/cpu/test_tilelang_cpu_gemm.py @@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): A_local = T.alloc_local((block_M, block_K), dtype) @@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo # ) for ko in T.Pipelined(K // block_K, num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_local) # Or Copy with Parallel @@ -62,14 +61,13 @@ def test_matmul_codegen(): def test_matmul_compile(): - def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): # a simple kernel just for jit test @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): A_local = T.alloc_local((block_M, block_K), dtype) diff --git a/testing/python/debug/test_device_assert.py b/testing/python/debug/test_device_assert.py index 1602c30..210b896 100644 --- a/testing/python/debug/test_device_assert.py +++ b/testing/python/debug/test_device_assert.py @@ -7,7 +7,6 @@ import tilelang.language as T # TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI # Please run manually when you want to verify that device_assert actually traps on GPU. def _manual_device_assert_triggered(): - @T.prim_func def program(): with T.Kernel(threads=128): @@ -20,7 +19,6 @@ def _manual_device_assert_triggered(): def test_device_assert_no_trigger(): - @T.prim_func def program(): with T.Kernel(threads=128): diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index a1aa42e..e262966 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -6,7 +6,6 @@ import tilelang.language as T def debug_print_buffer(M=16, N=16, dtype="float16"): - @T.prim_func def program(Q: T.Tensor((M, N), dtype)): with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): @@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): def test_debug_print_buffer(): - debug_print_buffer(dtype='bool') - debug_print_buffer(dtype='int8') - debug_print_buffer(dtype='int16') - debug_print_buffer(dtype='int32') - debug_print_buffer(dtype='int64') - debug_print_buffer(dtype='uint8') - debug_print_buffer(dtype='uint16') - debug_print_buffer(dtype='uint32') - debug_print_buffer(dtype='uint64') - debug_print_buffer(dtype='float16') - debug_print_buffer(dtype='float32') - debug_print_buffer(dtype='float64') - debug_print_buffer(dtype='bfloat16') - debug_print_buffer(dtype='float8_e4m3') - debug_print_buffer(dtype='float8_e4m3fn') - debug_print_buffer(dtype='float8_e4m3fnuz') - debug_print_buffer(dtype='float8_e5m2') - debug_print_buffer(dtype='float8_e5m2fnuz') + debug_print_buffer(dtype="bool") + debug_print_buffer(dtype="int8") + debug_print_buffer(dtype="int16") + debug_print_buffer(dtype="int32") + debug_print_buffer(dtype="int64") + debug_print_buffer(dtype="uint8") + debug_print_buffer(dtype="uint16") + debug_print_buffer(dtype="uint32") + debug_print_buffer(dtype="uint64") + debug_print_buffer(dtype="float16") + debug_print_buffer(dtype="float32") + debug_print_buffer(dtype="float64") + debug_print_buffer(dtype="bfloat16") + debug_print_buffer(dtype="float8_e4m3") + debug_print_buffer(dtype="float8_e4m3fn") + debug_print_buffer(dtype="float8_e4m3fnuz") + debug_print_buffer(dtype="float8_e5m2") + debug_print_buffer(dtype="float8_e5m2fnuz") def debug_print_buffer_conditional(M=16, N=16): diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py index 4b9dff7..8e50a27 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py @@ -5,7 +5,7 @@ import tilelang.testing from tvm import DataType import tilelang.language as T from tilelang.intrinsics.utils import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter) +from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter tilelang.testing.set_random_seed(0) @@ -96,12 +96,11 @@ def tl_matmul_macro( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -109,10 +108,12 @@ def tl_matmul_macro( B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -120,7 +121,6 @@ def tl_matmul_macro( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -130,7 +130,6 @@ def tl_matmul_macro( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -207,8 +206,7 @@ def tl_matmul_block( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( ) pass_configs = { tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0, - tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment + tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment, } if M % 64 == 0 or N % 64 == 0 or K % 64 != 0: # workaround for hopper tma lower pass @@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro(): def test_assert_tl_matmul_block(): - assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) - assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) - assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) + assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic(): - assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", - "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", - "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", - "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 128, - 128, - 128, - False, - False, - "float16", - "float16", - "float16", - 64, - 64, - 32, - dynamic_alignment=8) + 128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, - 128, - 128, - False, - False, - "float16", - "float16", - "float16", - 64, - 64, - 32, - dynamic_alignment=8) + 64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4) + 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4 + ) # Tail split is enabled with dynamic alignment 0 assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0) + 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0 + ) if __name__ == "__main__": diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py index b5ccbda..1bee135 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py @@ -25,10 +25,8 @@ def tl_matmul_block_static( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk( def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): - assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", - "float16", "float32") + assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32") def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): @@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 8 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, + ) assert_tl_matmul_block_dynamic_m( M, N, @@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): @@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 8 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, + ) assert_tl_matmul_block_dynamic_mn( M, N, @@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): @@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 4 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4}, + ) assert_tl_matmul_block_dynamic_mnk( M, N, @@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def test_all(): diff --git a/testing/python/fastmath/test_mathops_fastmath.py b/testing/python/fastmath/test_mathops_fastmath.py index c3b5d1b..7809983 100644 --- a/testing/python/fastmath/test_mathops_fastmath.py +++ b/testing/python/fastmath/test_mathops_fastmath.py @@ -7,16 +7,16 @@ import re def get_mathop_lines(source, mathop_name): """Extract lines containing the mathop from CUDA source for debugging""" - lines = source.split('\n') + lines = source.split("\n") relevant_lines = [] for i, line in enumerate(lines): - if mathop_name in line and ('(' in line): + if mathop_name in line and ("(" in line): # Include some context start = max(0, i - 1) end = min(len(lines), i + 2) relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) relevant_lines.append("---") - return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output def check_fastmath_usage(source, mathop_name, expect_fastmath=False): @@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False): fastmath_matches = re.findall(fastmath_pattern, source) non_fastmath_matches = re.findall(non_fastmath_pattern, source) - print( - f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls" - ) + print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls") if len(fastmath_matches) > 0: print(f"Fastmath calls found: {fastmath_matches}") if len(non_fastmath_matches) > 0: @@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test single-argument mathops. T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name, @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() @@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name, print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j]) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source() @@ -171,8 +159,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -184,7 +172,8 @@ def run_abs_test(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source = kernel.get_kernel_source() print("\n=== Testing abs (maps to fabs) ===") @@ -199,26 +188,19 @@ def run_abs_test(): print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_fastmath = kernel_fastmath.get_kernel_source() print(f"\n=== Testing {mathop_name} (fastmath version) ===") print("FAST_MATH=True:") # Strip the __ prefix for checking in the CUDA source - cuda_mathop_name = mathop_name.lstrip('_') + cuda_mathop_name = mathop_name.lstrip("_") check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py index 77d8cc1..a4283da 100644 --- a/testing/python/issue/test_tilelang_issue_1001.py +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -8,14 +8,15 @@ from tilelang import language as T pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _cumsum_view_infer_layout(hidden): - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']): + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]): with T.Kernel(num_tokens, threads=128) as pid: - smem = T.alloc_shared((hidden,), dtype='float') + smem = T.alloc_shared((hidden,), dtype="float") T.copy(x[pid, :], smem) T.cumsum(T.view(smem, (1, hidden)), dim=1) @@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden): def test_cumsum_view_infer_layout(): hidden = 128 - x = torch.randn(1, hidden, device='cuda', dtype=torch.float) + x = torch.randn(1, hidden, device="cuda", dtype=torch.float) kernel = _cumsum_view_infer_layout(hidden) kernel(x) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1008.py b/testing/python/issue/test_tilelang_issue_1008.py index 395593d..2d86d16 100644 --- a/testing/python/issue/test_tilelang_issue_1008.py +++ b/testing/python/issue/test_tilelang_issue_1008.py @@ -8,12 +8,13 @@ from tilelang import language as T pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _fill_with_static_region_kernel(): - num_tokens = T.symbolic('num_tokens') + num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 with T.Kernel(num_tokens, threads=128) as _: T.fill(x[0:128], 0) @@ -24,14 +25,15 @@ def _fill_with_static_region_kernel(): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _fill_with_dynamic_region_kernel(): - num_tokens = T.symbolic('num_tokens') + num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 with T.Kernel(num_tokens, threads=128) as _: - a, b = T.alloc_var('int'), T.alloc_var('int') + a, b = T.alloc_var("int"), T.alloc_var("int") T.fill(x[a:b], 0) return buggy_kernel @@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel(): def test_fill_with_static_region_kernel(): kernel = _fill_with_static_region_kernel() - x = torch.zeros((256,), dtype=torch.int64, device='cuda') + x = torch.zeros((256,), dtype=torch.int64, device="cuda") kernel(x) def test_fill_with_dynamic_region_kernel(): kernel = _fill_with_dynamic_region_kernel() - x = torch.zeros((256,), dtype=torch.int64, device='cuda') + x = torch.zeros((256,), dtype=torch.int64, device="cuda") kernel(x) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1115.py b/testing/python/issue/test_tilelang_issue_1115.py index 1769862..ce21a3b 100644 --- a/testing/python/issue/test_tilelang_issue_1115.py +++ b/testing/python/issue/test_tilelang_issue_1115.py @@ -4,25 +4,23 @@ import tilelang.language as T def test_int64_address(): - @tilelang.jit def set_cache_kernel( S, D, - pos_ty='int64', + pos_ty="int64", dtype="float32", ): - @T.prim_func def main( - pos: T - .Tensor( + pos: T.Tensor( [ S, - ], pos_ty + ], + pos_ty, ), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32` - value: T.Tensor([S, D], dtype), # type: ignore - cache: T.Tensor([S, D], dtype), # type: ignore + value: T.Tensor([S, D], dtype), # type: ignore + cache: T.Tensor([S, D], dtype), # type: ignore ): with T.Kernel(S, threads=128) as bx: slot = pos[bx] @@ -34,11 +32,11 @@ def test_int64_address(): D = 2 S = 10 cache = torch.rand((S, D), device="cuda", dtype=torch.float32) - value = torch.rand((S, D), device='cuda', dtype=torch.float32) - pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64) - pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32) - kernel_int64 = set_cache_kernel(S, D, 'int64') - kernel_int32 = set_cache_kernel(S, D, 'int32') + value = torch.rand((S, D), device="cuda", dtype=torch.float32) + pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64) + pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32) + kernel_int64 = set_cache_kernel(S, D, "int64") + kernel_int32 = set_cache_kernel(S, D, "int32") kernel_int64(pos_int64, value, cache) torch.testing.assert_close(cache, value) kernel_int32(pos_int32, value, cache) diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py index eb9ed45..08f3682 100644 --- a/testing/python/issue/test_tilelang_issue_1198.py +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -3,13 +3,17 @@ import tilelang.language as T def test_issue_1198(): - @T.prim_func - def foo(x: T.Buffer([ - 32, - ], "int32")): + def foo( + x: T.Buffer( + [ + 32, + ], + "int32", + ), + ): pass -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_814.py b/testing/python/issue/test_tilelang_issue_814.py index 1a9e63d..a202bd9 100644 --- a/testing/python/issue/test_tilelang_issue_814.py +++ b/testing/python/issue/test_tilelang_issue_814.py @@ -6,11 +6,10 @@ import torch @tilelang.jit def _tmp_var_kernel(N, block_N, dtype="float"): - @T.prim_func def kernel( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx: for i in T.Parallel(block_N): diff --git a/testing/python/issue/test_tilelang_issue_830.py b/testing/python/issue/test_tilelang_issue_830.py index 950b858..74ceed3 100644 --- a/testing/python/issue/test_tilelang_issue_830.py +++ b/testing/python/issue/test_tilelang_issue_830.py @@ -8,7 +8,6 @@ import tilelang.language as T @tilelang.jit def _empty_kernel(): - @T.prim_func def empty_kernel(): with T.Kernel(1, threads=32) as thread_idx: @@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel(): @tilelang.jit def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False): - @T.prim_func def kernel_with_tuple_kernel_binding(): with T.Kernel(1, threads=32) as (pid,): diff --git a/testing/python/issue/test_tilelang_issue_96.py b/testing/python/issue/test_tilelang_issue_96.py index e42ebb5..6ab7fe4 100644 --- a/testing/python/issue/test_tilelang_issue_96.py +++ b/testing/python/issue/test_tilelang_issue_96.py @@ -5,18 +5,16 @@ import torch def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) diff --git a/testing/python/issue/test_tilelang_issue_merge_if.py b/testing/python/issue/test_tilelang_issue_merge_if.py index 1db7f33..fa9432f 100644 --- a/testing/python/issue/test_tilelang_issue_merge_if.py +++ b/testing/python/issue/test_tilelang_issue_merge_if.py @@ -6,7 +6,6 @@ import tilelang.language as T def merge_if_test(): - @T.prim_func def main(): A = T.alloc_fragment((1,), "float16") diff --git a/testing/python/jit/test_tilelang_jit_callback.py b/testing/python/jit/test_tilelang_jit_callback.py index e987368..7d76a64 100644 --- a/testing/python/jit/test_tilelang_jit_callback.py +++ b/testing/python/jit/test_tilelang_jit_callback.py @@ -29,9 +29,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -141,9 +141,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -208,6 +208,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C diff --git a/testing/python/jit/test_tilelang_jit_gemm.py b/testing/python/jit/test_tilelang_jit_gemm.py index 25c19a0..153f06c 100644 --- a/testing/python/jit/test_tilelang_jit_gemm.py +++ b/testing/python/jit/test_tilelang_jit_gemm.py @@ -31,9 +31,9 @@ def matmul_kernel_jit( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -96,6 +96,7 @@ def run_gemm_kernel_jit( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index 12524f1..4ea4ba8 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -138,9 +138,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -208,6 +208,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -235,19 +236,9 @@ def test_gemm_jit_kernel(): ) -def run_cython_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_cython_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M, def test_cython_kernel_do_bench(): - run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_cython_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M, def test_cython_kernel_multi_stream(): - run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_cython_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_cython_dynamic_shape(): - run_cython_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - - run_cython_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - run_cython_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) - - -def run_cython_dynamic_shape_with_out_idx(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_dynamic_shape_with_out_idx( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M, tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_cython_dynamic_shape_with_out_idx(): - run_cython_dynamic_shape_with_out_idx( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) def matmul_int_variable( @@ -498,10 +449,10 @@ def matmul_int_variable( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - offset: T.int32, + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.int32, ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -525,10 +476,10 @@ def matmul_int_variable( return main -def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads): - program = matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads) +def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_int_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) in_dtype = map_torch_type(in_dtype) @@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B def test_matmul_int_variable(): - run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", - "float32", 0, 128) + run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) def matmul_float_variable( @@ -570,10 +520,10 @@ def matmul_float_variable( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - offset: T.float32, + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.float32, ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -597,10 +547,10 @@ def matmul_float_variable( return main -def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads): - program = matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads) +def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_float_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) in_dtype = map_torch_type(in_dtype) @@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans def test_matmul_float_variable(): - run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", - "float32", 0, 128) + run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) if __name__ == "__main__": diff --git a/testing/python/jit/test_tilelang_jit_nullptr.py b/testing/python/jit/test_tilelang_jit_nullptr.py index cce1fce..8965e2a 100644 --- a/testing/python/jit/test_tilelang_jit_nullptr.py +++ b/testing/python/jit/test_tilelang_jit_nullptr.py @@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type @tl.jit -def tensor_null_test(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float", - with_bias=False): - +def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), accum_dtype), - Bias: T.Tensor((N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + Bias: T.Tensor((N), accum_dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -48,12 +39,10 @@ def tensor_null_test(M, def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) - kernel = tensor_null_test( - M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) + kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) kernel(a, b, c, None) diff --git a/testing/python/jit/test_tilelang_jit_nvrtc.py b/testing/python/jit/test_tilelang_jit_nvrtc.py index c707686..2b15027 100644 --- a/testing/python/jit/test_tilelang_jit_nvrtc.py +++ b/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -136,9 +136,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -206,6 +206,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -233,19 +234,9 @@ def test_gemm_jit_kernel(): ) -def run_nvrtc_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_nvrtc_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M, def test_nvrtc_kernel_do_bench(): - run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_nvrtc_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_nvrtc_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M, def test_nvrtc_kernel_multi_stream(): - run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_nvrtc_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_nvrtc_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_nvrtc_dynamic_shape(): - run_nvrtc_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_nvrtc_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_nvrtc_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) def check_hopper(): @@ -412,35 +375,18 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -449,11 +395,13 @@ def convolution_im2col(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -467,23 +415,9 @@ def convolution_im2col(N, return main -def run_nvrtc_im2col_tma_desc(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=256): +def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): """Test im2col TMA descriptor functionality in NVRTC backend.""" - program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, - num_threads) + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") @@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N, return C ref_c = ref_program(a, b) - tilelang.testing.torch_assert_close( - out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_nvrtc_im2col_tma_desc(): """Test im2col TMA descriptor with NVRTC backend.""" if not check_hopper(): import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") # Small test case for im2col TMA descriptor run_nvrtc_im2col_tma_desc( - N=4, - C=64, - H=32, - W=32, - F=64, - K=3, - S=1, - D=1, - P=1, - block_M=64, - block_N=128, - block_K=32, - num_stages=3, - num_threads=256) + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) def test_nvrtc_l2_persistent_map(): @@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map(): block_size=256, dtype="float32", ): - @T.prim_func def kernel( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(M * N // block_size, threads=block_size) as bx: # Annotate L2 persistent cache for buffer B diff --git a/testing/python/jit/test_tilelang_jit_parcompile.py b/testing/python/jit/test_tilelang_jit_parcompile.py index e7bcec4..0a6e906 100644 --- a/testing/python/jit/test_tilelang_jit_parcompile.py +++ b/testing/python/jit/test_tilelang_jit_parcompile.py @@ -16,9 +16,9 @@ def matmul_kernel_jit( block_K, trans_A=False, trans_B=True, - in_dtype='float16', - out_dtype='float32', - accum_dtype='float32', + in_dtype="float16", + out_dtype="float32", + accum_dtype="float32", num_stages=2, threads=128, ): @@ -31,9 +31,9 @@ def matmul_kernel_jit( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py index f7bde6a..5daaf30 100644 --- a/testing/python/jit/test_tilelang_jit_tvm_ffi.py +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -74,9 +74,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -144,6 +144,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -171,19 +172,9 @@ def test_gemm_jit_kernel(): ) -def run_tvm_ffi_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_tvm_ffi_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M, def test_tvm_ffi_kernel_do_bench(): - run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_tvm_ffi_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_tvm_ffi_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M, def test_tvm_ffi_kernel_multi_stream(): - run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_tvm_ffi_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_tvm_ffi_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_tvm_ffi_dynamic_shape(): - run_tvm_ffi_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_tvm_ffi_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) run_tvm_ffi_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) def check_hopper(): @@ -350,35 +315,18 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -387,11 +335,13 @@ def convolution_im2col(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -405,23 +355,9 @@ def convolution_im2col(N, return main -def run_tvm_ffi_im2col_tma_desc(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=256): +def run_tvm_ffi_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): """Test im2col TMA descriptor functionality in tvm_ffi backend.""" - program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, - num_threads) + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") @@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N, return C ref_c = ref_program(a, b) - tilelang.testing.torch_assert_close( - out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_tvm_ffi_im2col_tma_desc(): """Test im2col TMA descriptor with tvm_ffi backend.""" if not check_hopper(): import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") # Small test case for im2col TMA descriptor run_tvm_ffi_im2col_tma_desc( - N=4, - C=64, - H=32, - W=32, - F=64, - K=3, - S=1, - D=1, - P=1, - block_M=64, - block_N=128, - block_K=32, - num_stages=3, - num_threads=256) + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) def test_tvm_ffi_l2_persistent_map(): @@ -481,12 +405,11 @@ def test_tvm_ffi_l2_persistent_map(): block_size=256, dtype="float32", ): - @T.prim_func def kernel( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(M * N // block_size, threads=block_size) as bx: # Annotate L2 persistent cache for buffer B @@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map(): kernel = elementwise_add_with_l2_cache(M, N) source = kernel.get_host_source() - assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" - assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" + ) + assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + ) # Create test tensors a = torch.randn(M, N, dtype=torch.float32).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index 13135d4..e7d7021 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -111,12 +112,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -124,10 +124,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -135,7 +137,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -145,7 +146,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py index 3ec6ae0..52763c8 100644 --- a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py +++ b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py @@ -16,15 +16,15 @@ def elementwise_add( @T.prim_func def main( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, N), in_dtype), + B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): start_x = bx * block_N start_y = by * block_M - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): y = start_y + local_y x = start_x + local_x diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py index 19f327d..63c8212 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py @@ -12,12 +12,11 @@ def calc_diff(x, y): def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype): - @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((N, K), in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by): A_shared = T.alloc_shared((bM, bK), in_dtype) @@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ C = kernel(A, B) - ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), - B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) + ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) print(C) print(ref_c) diff = calc_diff(C, ref_c) diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index 46f4e12..eec3a9c 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -110,12 +111,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -123,10 +123,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -134,7 +136,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -144,7 +145,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py index afd01f3..4a48b65 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py @@ -27,8 +27,8 @@ def gemv_simt( ): assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" @@ -50,16 +50,15 @@ def gemv_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor(C_shape, out_dtype), ): - with T.Kernel( - T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( + bx, + by, + ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype) accum_res = T.alloc_local((1,), accum_dtype) @@ -88,13 +87,12 @@ def gemv_simt( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( - accum_dtype) + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -104,11 +102,11 @@ def gemv_simt( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: if with_bias: - C[by, - bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] else: C[by, bx * n_partition + ni] = reduced_accum_res[0] diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py index 5dcde1d..6c01297 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -95,8 +95,8 @@ def run_gemm( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -321,9 +321,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -441,9 +441,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index 6e20754..3633d3e 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -111,12 +112,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -124,10 +124,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -135,7 +137,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -145,7 +146,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py index 548497c..e4da44b 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -76,12 +76,11 @@ def tl_matmul_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor(C_shape, out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) @@ -97,7 +96,6 @@ def tl_matmul_simt( T.clear(C_local) for ko in T.serial(K // block_K): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -109,29 +107,24 @@ def tl_matmul_simt( for ki in T.serial((block_K // micro_size_k)): for i in T.serial(local_size_a): for mk in T.vectorized(micro_size_k): - A_local[i, mk] = A_shared[warp_m * local_size_a + i, - ki * micro_size_k + mk] + A_local[i, mk] = A_shared[warp_m * local_size_a + i, ki * micro_size_k + mk] for i in T.serial(local_size_b): for mk in T.vectorized(micro_size_k): - B_local[i, mk] = B_shared[warp_n * local_size_b + i, - ki * micro_size_k + mk] + B_local[i, mk] = B_shared[warp_n * local_size_b + i, ki * micro_size_k + mk] for i, j in T.grid(local_size_a, local_size_b): for mk in T.serial(micro_size_k // dp4a_size): if use_dp4a: - T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], - C_local[i * local_size_b + j]) + T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], C_local[i * local_size_b + j]) else: for dp4a_idx in T.serial(dp4a_size): - C_local[i * local_size_b + - j] += A_local[i, mk * dp4a_size + - dp4a_idx] * B_local[j, mk * dp4a_size + - dp4a_idx] + C_local[i * local_size_b + j] += ( + A_local[i, mk * dp4a_size + dp4a_idx] * B_local[j, mk * dp4a_size + dp4a_idx] + ) for i, j in T.grid(local_size_a, local_size_b): - C[by * block_M + warp_m * local_size_a + i, - bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] + C[by * block_M + warp_m * local_size_a + i, bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] return main diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py index bbc2e79..2def480 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py @@ -5,12 +5,11 @@ import torch def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create random input tensors on the GPU a = torch.randn(M, K, device="cuda", dtype=torch.float16) b = torch.randn(K, N, device="cuda", dtype=torch.float16) diff --git a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py index 86d6acb..5825f69 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py @@ -27,8 +27,8 @@ def gemv_simt( ): assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" @@ -50,16 +50,15 @@ def gemv_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor(C_shape, out_dtype), ): - with T.Kernel( - T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( + bx, + by, + ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype) accum_res = T.alloc_local((1,), accum_dtype) @@ -88,13 +87,12 @@ def gemv_simt( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( - accum_dtype) + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -104,11 +102,11 @@ def gemv_simt( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: if with_bias: - C[by, - bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] else: C[by, bx * n_partition + ni] = reduced_accum_res[0] diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py index 5cdd671..affeb3d 100644 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -4,7 +4,8 @@ from tilelang import tvm as tvm import tilelang.testing import tilelang.language as T from tilelang.intrinsics import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) from tilelang.intrinsics.mma_macro_generator import ( INT4TensorCoreIntrinEmitter, @@ -91,12 +92,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -104,10 +104,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -115,7 +117,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -125,7 +126,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): out_idx=[2], pass_configs={ tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True, - }) + }, + ) print(kernel.get_kernel_source()) profiler = kernel.get_profiler() @@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform( T.clear(C_local) for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] # Load B into shared memory - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, - micro_size_y, micro_size_k): - B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk] + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * (block_K // micro_size_k) + k, jj, kk] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform( def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): import bitblas + matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) kernel = tilelang.compile(matmul, out_idx=[2]) profiler = kernel.get_profiler() diff --git a/testing/python/language/test_tilelang_capture.py b/testing/python/language/test_tilelang_capture.py index 875fa68..47fec99 100644 --- a/testing/python/language/test_tilelang_capture.py +++ b/testing/python/language/test_tilelang_capture.py @@ -6,16 +6,17 @@ import gc def test_tilelang_capture(): - @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - },) + }, + ) def get_dummy_kernel(): - @T.prim_func - def dummy_kernel(a: T.Tensor[(1,), T.float32],): + def dummy_kernel( + a: T.Tensor[(1,), T.float32], + ): with T.Kernel(1) as _: a[0] = 1 @@ -36,5 +37,5 @@ def test_tilelang_capture(): # objgraph.show_backrefs([a_upgrade], max_depth=5) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_intimm.py b/testing/python/language/test_tilelang_intimm.py index 58fea31..46c2c79 100644 --- a/testing/python/language/test_tilelang_intimm.py +++ b/testing/python/language/test_tilelang_intimm.py @@ -4,25 +4,25 @@ import tilelang.language as T def test_tilelang_intimm(): - T.int32(0x7fffffff) - T.int32(-0x7fffffff - 1) - T.uint32(0xffffffff) - T.int64(0x7fffffffffffffff) - T.int64(-0x7fffffffffffffff - 1) - T.uint64(0xffffffffffffffff) + T.int32(0x7FFFFFFF) + T.int32(-0x7FFFFFFF - 1) + T.uint32(0xFFFFFFFF) + T.int64(0x7FFFFFFFFFFFFFFF) + T.int64(-0x7FFFFFFFFFFFFFFF - 1) + T.uint64(0xFFFFFFFFFFFFFFFF) a = T.int32() - a & 0x7fffffff + a & 0x7FFFFFFF a = T.uint32() - a & 0xffffffff + a & 0xFFFFFFFF a = T.int64() - a & 0x7fffffffffffffff + a & 0x7FFFFFFFFFFFFFFF a = T.uint64() - a & T.uint64(0xffffffffffffffff) + a & T.uint64(0xFFFFFFFFFFFFFFFF) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_alias.py b/testing/python/language/test_tilelang_language_alias.py index c99d361..f55d9e8 100644 --- a/testing/python/language/test_tilelang_language_alias.py +++ b/testing/python/language/test_tilelang_language_alias.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/testing/python/language/test_tilelang_language_all_of.py b/testing/python/language/test_tilelang_language_all_of.py index 73233ec..4841212 100644 --- a/testing/python/language/test_tilelang_language_all_of.py +++ b/testing/python/language/test_tilelang_language_all_of.py @@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if torch.all(BlockMask[i, j, k]): - accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32) - ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = ( - accu.to(torch.float16)) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -35,15 +34,14 @@ def blocksparse_matmul_global( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -80,15 +78,14 @@ def blocksparse_matmul_shared( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -130,15 +127,14 @@ def blocksparse_matmul_local( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity diff --git a/testing/python/language/test_tilelang_language_alloc.py b/testing/python/language/test_tilelang_language_alloc.py index 149a1c2..6695e93 100644 --- a/testing/python/language/test_tilelang_language_alloc.py +++ b/testing/python/language/test_tilelang_language_alloc.py @@ -10,8 +10,8 @@ def alloc_var( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -50,8 +50,8 @@ def alloc_var_add( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -91,8 +91,8 @@ def alloc_var_with_initializer( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: tmp = T.alloc_var(dtype, init_value) @@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: tmp0 = T.alloc_var(dtype, 1) diff --git a/testing/python/language/test_tilelang_language_annot.py b/testing/python/language/test_tilelang_language_annot.py index 7425bf5..5c9aeea 100644 --- a/testing/python/language/test_tilelang_language_annot.py +++ b/testing/python/language/test_tilelang_language_annot.py @@ -5,13 +5,14 @@ import torch def test_tensor_annot_mul(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n * 4,), T.int32),): + def kernel( + A: T.Tensor((n * 4,), T.int32), + ): with T.Kernel(1) as _: for i in range(n * 4): A[i] = 0 @@ -19,20 +20,21 @@ def test_tensor_annot_mul(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) def test_tensor_annot_add(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n + 1,), T.int32),): + def kernel( + A: T.Tensor((n + 1,), T.int32), + ): with T.Kernel(1) as _: for i in range(n + 1): A[i] = 0 @@ -40,20 +42,21 @@ def test_tensor_annot_add(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) def test_tensor_annot_mul_add(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n * 3 + 1,), T.int32),): + def kernel( + A: T.Tensor((n * 3 + 1,), T.int32), + ): with T.Kernel(1) as _: for i in range(n * 3 + 1): A[i] = 0 @@ -61,11 +64,11 @@ def test_tensor_annot_mul_add(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_annotate_safe_value.py b/testing/python/language/test_tilelang_language_annotate_safe_value.py index 3d616ac..442172b 100644 --- a/testing/python/language/test_tilelang_language_annotate_safe_value.py +++ b/testing/python/language/test_tilelang_language_annotate_safe_value.py @@ -7,11 +7,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0): def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0): program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) ref_b = torch.zeros_like(a) diff --git a/testing/python/language/test_tilelang_language_any_of.py b/testing/python/language/test_tilelang_language_any_of.py index 354d32c..37605e5 100644 --- a/testing/python/language/test_tilelang_language_any_of.py +++ b/testing/python/language/test_tilelang_language_any_of.py @@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if torch.any(BlockMask[i, j, k]): - accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32) - ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = ( - accu.to(torch.float16)) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -35,15 +34,14 @@ def blocksparse_matmul_global( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -80,15 +78,14 @@ def blocksparse_matmul_shared( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -130,15 +127,14 @@ def blocksparse_matmul_local( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity diff --git a/testing/python/language/test_tilelang_language_assume.py b/testing/python/language/test_tilelang_language_assume.py index 9c75a5a..32e6b1c 100644 --- a/testing/python/language/test_tilelang_language_assume.py +++ b/testing/python/language/test_tilelang_language_assume.py @@ -4,10 +4,9 @@ import tilelang.testing def test_assume_remove_boundary_check(): - @tilelang.jit def kernel_with_assume(): - N = T.dynamic('N') + N = T.dynamic("N") @T.prim_func def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32): @@ -21,20 +20,19 @@ def test_assume_remove_boundary_check(): jit_kernel = kernel_with_assume() source = jit_kernel.get_kernel_source() - assert ("if (" not in source) + assert "if (" not in source def test_assume_enable_vectorization(): - @tilelang.jit def kernel_vectorize(M): - N = T.dynamic('N') + N = T.dynamic("N") vectorize_size = 4 @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() @@ -55,16 +53,15 @@ def test_assume_enable_vectorization(): def test_assume_complex_indexing(): - @tilelang.jit def kernel_complex(): - M = T.dynamic('M') - N = T.dynamic('N') + M = T.dynamic("M") + N = T.dynamic("N") @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() @@ -82,8 +79,8 @@ def test_assume_complex_indexing(): jit_kernel = kernel_complex() source = jit_kernel.get_kernel_source() - assert ("if (" not in source) + assert "if (" not in source -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index b157966..eaf5ae1 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -4,14 +4,12 @@ import tilelang.language as T @tilelang.jit def atomic_add_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) @@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) T.atomic_add(B[bx * block_M, by * block_N], A_shared) @@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_max_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) @@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_min_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) @@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"): B[i, j] = min(B[i, j], A[k, i, j]) A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() ref_B = B.clone() ref_program(A, ref_B) kernel(A, B) @@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_load_store_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed") + T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed") return atomic_with_memory_order @@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_addx2_program(M, N, block_M, block_N): - @T.prim_func def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -262,10 +249,10 @@ def test_atomic_addx2(): @tilelang.jit def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func - def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor( - (M, N), dtype), D: T.Tensor((M, N), dtype)): + def atomic_different_orders( + A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype) + ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N): idx_i = bx * block_M + i @@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"): A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() - D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda() + D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() kernel(A, B, C, D) torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A)) - torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A)) + torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A)) @tilelang.jit def atomic_addx4_program(M, N, block_M, block_N): - @T.prim_func def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N): @tilelang.jit def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func - def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), - old_vals: T.Tensor((M, N), dtype)): + def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N): idx_i = bx * block_M + i idx_j = by * block_N + j if idx_i < M and idx_j < N: - old_vals[idx_i, idx_j] = T.atomic_add( - B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True) + old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True) return atomic_with_return_prev diff --git a/testing/python/language/test_tilelang_language_ceildiv.py b/testing/python/language/test_tilelang_language_ceildiv.py index 35201a0..66215ab 100644 --- a/testing/python/language/test_tilelang_language_ceildiv.py +++ b/testing/python/language/test_tilelang_language_ceildiv.py @@ -5,7 +5,6 @@ import torch @tilelang.jit(out_idx=[-1]) def _ceildiv_kernel(a: int, b: int): - @T.prim_func def ceildiv_kernel(A: T.Tensor((1,), "int32")): with T.Kernel(1, threads=1) as _: @@ -30,7 +29,6 @@ def test_ceildiv(): @tilelang.jit def _ceildiv_kernel_dyn(b: int): - @T.prim_func def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32): with T.Kernel(1, threads=1) as _: diff --git a/testing/python/language/test_tilelang_language_chain_equal.py b/testing/python/language/test_tilelang_language_chain_equal.py index 696a9c7..0a9623f 100644 --- a/testing/python/language/test_tilelang_language_chain_equal.py +++ b/testing/python/language/test_tilelang_language_chain_equal.py @@ -8,14 +8,14 @@ import torch pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - },) + }, +) def chain_equal(N, block_size, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx: for lane in T.Parallel(block_size): diff --git a/testing/python/language/test_tilelang_language_clamp.py b/testing/python/language/test_tilelang_language_clamp.py index 4a2f177..06e558f 100644 --- a/testing/python/language/test_tilelang_language_clamp.py +++ b/testing/python/language/test_tilelang_language_clamp.py @@ -13,8 +13,8 @@ def clamp_within_bounds( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -56,8 +56,8 @@ def clamp_value_range( @T.prim_func def main( - A: T.Tensor((1, N), dtype), - B: T.Tensor((1, N), dtype), + A: T.Tensor((1, N), dtype), + B: T.Tensor((1, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: # A_shared = T.alloc_shared([1, block_N], dtype=dtype) diff --git a/testing/python/language/test_tilelang_language_clear.py b/testing/python/language/test_tilelang_language_clear.py index be3d808..19ae0bb 100644 --- a/testing/python/language/test_tilelang_language_clear.py +++ b/testing/python/language/test_tilelang_language_clear.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) - kernel = tilelang.compile( - program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True}) import torch from tilelang.utils import map_torch_type + a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda() b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda() c = kernel(a, b) diff --git a/testing/python/language/test_tilelang_language_composable_index.py b/testing/python/language/test_tilelang_language_composable_index.py index ac2254f..8a58695 100644 --- a/testing/python/language/test_tilelang_language_composable_index.py +++ b/testing/python/language/test_tilelang_language_composable_index.py @@ -7,11 +7,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M * N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M * N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py index 4a2ddee..367f8ed 100644 --- a/testing/python/language/test_tilelang_language_copy.py +++ b/testing/python/language/test_tilelang_language_copy.py @@ -7,11 +7,10 @@ import tilelang.testing # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") program, out_idx=[1], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -43,11 +40,10 @@ def test_tilelang_copy(): def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.StridedTensor((M, N), (NN, 1), dtype), - B: T.Tensor((M, N), dtype), + A: T.StridedTensor((M, N), (NN, 1), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): return main -def run_tilelang_copy_with_stride(M=1024, - N=1024, - NN=2048, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"): if isinstance(NN, int): assert NN > N, "NN must be greater than N" program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) @@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024, pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }) + }, + ) if isinstance(NN, T.Var): NN = N * 2 a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype)) @@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride(): def tilelang_copy_bufferload(num_tokens, dtype="float16"): - @T.prim_func def main( - indices: T.Tensor((num_tokens,), "int32"), - x: T.Tensor((num_tokens,), dtype), + indices: T.Tensor((num_tokens,), "int32"), + x: T.Tensor((num_tokens,), dtype), ): with T.Kernel(num_tokens, threads=32) as pid: idx = T.alloc_local([1], "int32") @@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"): tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) def test_tilelang_copy_bufferload(): @@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload(): def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float return main -def run_tilelang_copy_buffer_load_with_parallel(M=1024, - N=1024, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_cumsum.py b/testing/python/language/test_tilelang_language_cumsum.py index 0046405..76982a4 100644 --- a/testing/python/language/test_tilelang_language_cumsum.py +++ b/testing/python/language/test_tilelang_language_cumsum.py @@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3 @T.prim_func def cumsum( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl @T.prim_func def cumsum( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc ref_b = torch.empty_like(A) for i in range(M // block_M): for j in range(N // block_N): - ref_b[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = A[i * block_M:(i + 1) * block_M, j * - block_N:(j + 1) * block_N].cumsum(dim=dim) + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[ + i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N + ].cumsum(dim=dim) if reverse: - ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * - block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * - block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim]) + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = ( + A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] + .flip(dims=[dim]) + .cumsum(dim=dim) + .flip(dims=[dim]) + ) return ref_b tilelang_res = jit_kernel(A) @@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"): @T.prim_func def cumsum( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared((block_N,), dtype) @@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"): @T.prim_func def cumsum( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared((block_N,), dtype) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 78d38f3..b0191b4 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var def test_argument(): - @T.prim_func def test_argument( t_1: T.bool, @@ -41,6 +40,7 @@ def test_argument(): def test_expr(): from tilelang.language.v2.dtypes import _all_dtypes + errors = [] for name in _all_dtypes: dtype = getattr(T, name) @@ -116,33 +116,32 @@ def test_expr(): def test_dtype_str_repr(): - @T.prim_func def test_str_repr(): - buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841 - buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841 - buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841 - buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 - buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841 - buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841 - buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 - buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841 - buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841 - buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841 - buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841 - buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841 - buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841 - buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841 - buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841 - buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841 - buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841 - buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841 - buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841 - buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841 - buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841 - buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841 - buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841 - buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 + buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared") # noqa F841 + buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared") # noqa F841 + buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared") # noqa F841 + buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared") # noqa F841 + buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared") # noqa F841 + buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared") # noqa F841 + buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared") # noqa F841 + buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared") # noqa F841 + buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared") # noqa F841 + buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared") # noqa F841 + buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared") # noqa F841 + buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared") # noqa F841 + buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared") # noqa F841 + buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared") # noqa F841 + buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared") # noqa F841 + buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared") # noqa F841 + buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared") # noqa F841 + buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared") # noqa F841 + buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared") # noqa F841 + buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared") # noqa F841 + buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared") # noqa F841 + buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared") # noqa F841 # not supported now @@ -205,7 +204,6 @@ def test_dtype_str_repr(): def test_var_assign(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_var_assign(A: T.Tensor((2,), T.int32)): @@ -223,7 +221,6 @@ def test_var_assign(): def test_marco_return(): - @T.macro def macro_return_constant(): return 0 @@ -258,11 +255,10 @@ def test_marco_return(): def test_prim_func_generator(): - @T.prim_func(generator=True) def prim_func_gen( - A=T.Tensor((128,), T.float32), # noqa: B008 - B=T.Tensor((128,), T.float32), # noqa: B008 + A=T.Tensor((128,), T.float32), # noqa: B008 + B=T.Tensor((128,), T.float32), # noqa: B008 ): with T.Kernel(128) as (tx,): T.copy(A[tx], B[tx]) @@ -277,7 +273,6 @@ def test_prim_func_generator(): def test_serial_for_with_step(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_stepped_serial(A: T.Tensor((10,), T.int32)): @@ -291,7 +286,7 @@ def test_serial_for_with_step(): ker = test_stepped_serial() res = ker() - ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda') + ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" @tilelang.jit(out_idx=-1) @@ -304,17 +299,16 @@ def test_serial_for_with_step(): ker = test_serial_step_neg() res = ker() - ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda') + ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" assert isinstance(T.serial(1, 10, 1), IRBuilderFrame) - assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame) - assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame) + assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame) + assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame) assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame) def test_swap_logic(): - @tilelang.jit @T.prim_func def swap_var(A: T.Tensor[(2,), T.float32]): @@ -344,7 +338,6 @@ def test_swap_logic(): def test_while_loop(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_while_loop(A: T.Tensor((1,), T.int32)): @@ -374,7 +367,7 @@ def test_var_macro(): x = T.alloc_var(T.int32) macro_with_var(x) - assert 'x[0] = 1' in prim_call_macro.script() + assert "x[0] = 1" in prim_call_macro.script() finally: pass @@ -406,7 +399,7 @@ def test_var_macro(): x = T.alloc_var(T.int32) macro_with_var(x) - assert 'x[0] = 1' in prim_call_macro.script() + assert "x[0] = 1" in prim_call_macro.script() finally: pass @@ -428,10 +421,8 @@ def test_var_macro(): def test_frame_inside_macro(): - @tilelang.jit def get_sample_kernel(): - @T.macro def transform(x): return x + 1 @@ -442,7 +433,7 @@ def test_frame_inside_macro(): idx_out: T.Tensor[(32,), T.int32], ): with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 - fragment = T.alloc_fragment(32, 'int32') + fragment = T.alloc_fragment(32, "int32") T.copy(idx_out, fragment) for i in T.Parallel(32): @@ -467,10 +458,10 @@ def test_buffer_slice_step(): def test_boolop(): - a = Var('a', 'int32') - b = Var('b', 'int32') - c = Var('c', 'int32') - d = Var('d', 'int32') + a = Var("a", "int32") + b = Var("b", "int32") + c = Var("c", "int32") + d = Var("d", "int32") @T.macro def cond(): @@ -479,5 +470,5 @@ def test_boolop(): cond() -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_get_warp_info.py b/testing/python/language/test_tilelang_language_get_warp_info.py index 68b65fc..edbc511 100644 --- a/testing/python/language/test_tilelang_language_get_warp_info.py +++ b/testing/python/language/test_tilelang_language_get_warp_info.py @@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: @tilelang.jit(out_idx=[-1]) def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def laneid_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): @tilelang.jit(out_idx=[-1]) def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = @tilelang.jit(out_idx=[-1]) def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel( warp_size: Optional[int] = None, warps_per_group: Optional[int] = None, ): - @T.prim_func def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel( @tilelang.jit(out_idx=[-1]) def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64): - @T.prim_func def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: diff --git a/testing/python/language/test_tilelang_language_if_range.py b/testing/python/language/test_tilelang_language_if_range.py index b3550f5..9c98456 100644 --- a/testing/python/language/test_tilelang_language_if_range.py +++ b/testing/python/language/test_tilelang_language_if_range.py @@ -4,13 +4,14 @@ import torch import tilelang.testing -@tilelang.jit(out_idx=[1],) +@tilelang.jit( + out_idx=[1], +) def tilelang_if_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/testing/python/language/test_tilelang_language_infinity.py b/testing/python/language/test_tilelang_language_infinity.py index 0779bff..5d25186 100644 --- a/testing/python/language/test_tilelang_language_infinity.py +++ b/testing/python/language/test_tilelang_language_infinity.py @@ -5,7 +5,6 @@ import tilelang.language as T @tilelang.jit(out_idx=-1) def get_inf_kernel(dtype: str): - @T.prim_func def main(A: T.Tensor((32,), dtype)): with T.Kernel(1, threads=32): @@ -18,7 +17,7 @@ def _test_infinity(dtype: str): kernel = get_inf_kernel(dtype) output = kernel() - assert torch.all(output == torch.inf), f'check failed for {dtype=}' + assert torch.all(output == torch.inf), f"check failed for {dtype=}" @tilelang.testing.requires_cuda diff --git a/testing/python/language/test_tilelang_language_intrinsics_codegen.py b/testing/python/language/test_tilelang_language_intrinsics_codegen.py index f817be2..8031824 100644 --- a/testing/python/language/test_tilelang_language_intrinsics_codegen.py +++ b/testing/python/language/test_tilelang_language_intrinsics_codegen.py @@ -9,8 +9,8 @@ def test_language_ldg_codegen(): @T.prim_func def main( - x: T.Tensor((N,), "float32"), - y: T.Tensor((N,), "float32"), + x: T.Tensor((N,), "float32"), + y: T.Tensor((N,), "float32"), ): with T.Kernel(N, threads=32) as pid: # Explicitly request read-only cache load for x[pid] diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_lazy_jit.py index d3b20c6..31da09c 100644 --- a/testing/python/language/test_tilelang_language_lazy_jit.py +++ b/testing/python/language/test_tilelang_language_lazy_jit.py @@ -8,7 +8,6 @@ import torch def _gemm_impl(): - @T.macro def gemm_impl( A: T.Tensor[[int, int], Any], @@ -37,7 +36,6 @@ def _gemm_impl(): def test_jit2_gemm_annot(): - @tilelang.lazy_jit def gemm( A: T.Tensor[[int, int], Any], @@ -54,24 +52,24 @@ def test_jit2_gemm_annot(): return C prod = product([T.float16, T.float32], [T.float32]) - gemm.par_compile([{ - 'A': T.Tensor((1024, 1024), dtype=in_dtype), - 'B': T.Tensor((1024, 1024), dtype=in_dtype), - 'out_dtype': out_dtype - } for in_dtype, out_dtype in prod]) + gemm.par_compile( + [ + {"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) for in_dtype, out_dtype in prod: in_dtype = in_dtype.torch() out_dtype = out_dtype.torch() - A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') - B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) C = gemm(A, B) torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2) def test_jit2_gemm_ptr(): - @tilelang.lazy_jit def gemm_ptr( A: T.ptr, @@ -92,23 +90,19 @@ def test_jit2_gemm_ptr(): _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K) prod = product([T.float16, T.float32], [T.float32]) - gemm_ptr.par_compile([{ - 'A': T.ptr(), - 'B': T.ptr(), - 'C': T.ptr(), - 'M': 1024, - 'N': 1024, - 'K': 1024, - 'dtype': in_dtype, - 'out_dtype': out_dtype - } for in_dtype, out_dtype in prod]) + gemm_ptr.par_compile( + [ + {"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) for in_dtype, out_dtype in prod: in_dtype = in_dtype.torch() out_dtype = out_dtype.torch() - A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') - B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) - C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda') + C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda") gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype) torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2) @@ -129,8 +123,7 @@ def test_jit2_annot(): AnnotTest( annot=T.Tensor[[int, int], T.float32], promote=False, - match_ok=[torch.randn(1, 1, dtype=torch.float32), - T.Tensor((1, 1), dtype=T.float32)], + match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)], match_ng=[ torch.randn(1, 1, dtype=torch.float16), T.Tensor(1, dtype=T.float32), @@ -146,8 +139,8 @@ def test_jit2_annot(): T.Tensor((1,), dtype=T.float32), T.Tensor((1,), dtype=T.float16), ], - match_ng=[torch.randn((1, 1), dtype=torch.float32), - T.Tensor((1, 1), dtype=T.float16)]), + match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)], + ), AnnotTest( annot=T.Tensor[[int, 1], Any], promote=False, @@ -157,8 +150,8 @@ def test_jit2_annot(): T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float16), ], - match_ng=[torch.randn(12, 12, dtype=torch.float32), - T.Tensor((12, 12), T.float32)]), + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), AnnotTest( annot=T.Tensor[[T.dyn, 1], Any], promote=False, @@ -168,43 +161,39 @@ def test_jit2_annot(): T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float16), ], - match_ng=[torch.randn(12, 12, dtype=torch.float32), - T.Tensor((12, 12), T.float32)]), + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), AnnotTest( annot=T.Tensor[[1024, 1024], T.float32], promote=True, ), - AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]), - AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]) + AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]), + AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]), ] for test in tests: promote = test.annot.promote() promoted = promote is not None if promoted != test.promote: - raise AssertionError( - f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}') - with Builder().prim_func('_test'): + raise AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}") + with Builder().prim_func("_test"): for match_ok in test.match_ok: try: vt = ArgVarTable() - test.annot.create_prim_func_arg('arg', match_ok, vt) + test.annot.create_prim_func_arg("arg", match_ok, vt) except Exception as e: traceback.print_exc() - raise AssertionError( - f'Match failed for {test.annot} with value {match_ok}: {e}') from e + raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e for match_ng in test.match_ng: try: vt = ArgVarTable() - test.annot.create_prim_func_arg('arg', match_ng, vt) - raise AssertionError( - f'Match unexpectedly succeeded for {test.annot} with value {match_ng}') + test.annot.create_prim_func_arg("arg", match_ng, vt) + raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}") except Exception: pass def test_jit2_many_annot(): - @T.macro def copy_impl(A, B): M, N = A.shape @@ -213,8 +202,7 @@ def test_jit2_many_annot(): assert N == N_, f"N mismatch {N} {N_}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): - T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, - by * 128:by * 128 + 128]) + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) @tilelang.lazy_jit def copy1( @@ -259,20 +247,19 @@ def test_jit2_many_annot(): copy_impl(A, B) for copy in [copy1, copy2, copy3, copy4]: - A = torch.randn(128, 128, device='cuda') - B = torch.empty(128, 128, device='cuda') + A = torch.randn(128, 128, device="cuda") + B = torch.empty(128, 128, device="cuda") copy(A, B) assert torch.equal(B, A) for copy in [copy5, copy6]: - A = torch.randn(128, 2, 128, 2, device='cuda') - B = torch.randn(128, 2, 128, 2, device='cuda') + A = torch.randn(128, 2, 128, 2, device="cuda") + B = torch.randn(128, 2, 128, 2, device="cuda") copy(A[:, 0, :, 0], B[:, 0, :, 0]) assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0]) def test_jit2_return(): - @T.macro def copy_impl(A): M, N = A.shape @@ -283,8 +270,7 @@ def test_jit2_return(): assert N == N_, f"N mismatch {N} {N_}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): - T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, - by * 128:by * 128 + 128]) + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) return B @tilelang.lazy_jit @@ -292,41 +278,52 @@ def test_jit2_return(): return copy_impl(A) @tilelang.lazy_jit - def copy1(A: T.Tensor[[int, int], T.float32],): + def copy1( + A: T.Tensor[[int, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy2(A: T.Tensor[[128, 128], T.float32],): + def copy2( + A: T.Tensor[[128, 128], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy3(A: T.Tensor[[int, 128], T.float32],): + def copy3( + A: T.Tensor[[int, 128], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy4(A: T.Tensor[[T.dyn, int], T.float32],): + def copy4( + A: T.Tensor[[T.dyn, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],): + def copy5( + A: T.StridedTensor[[int, int], [int, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],): + def copy6( + A: T.StridedTensor[[T.dyn, int], [int, int], T.float32], + ): return copy_impl(A) for copy in [copy0, copy1, copy2, copy3, copy4]: - A = torch.randn(128, 128, device='cuda') + A = torch.randn(128, 128, device="cuda") B = copy(A) assert torch.equal(B, A) for copy in [copy5, copy6]: - A = torch.randn(128, 2, 128, 2, device='cuda') + A = torch.randn(128, 2, 128, 2, device="cuda") B = copy(A[:, 0, :, 0]) assert torch.equal(A[:, 0, :, 0], B) def test_jit2_deepseek_deepgemm(): - @tilelang.lazy_jit def deep_gemm( A: T.Tensor[[int, int], T.float8_e4m3], @@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm(): N, K = B.shape C = T.empty(M, N, dtype=out_dtype) - assert out_dtype in [ - T.bfloat16, T.float32 - ], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" - assert scales_a.shape == [M, T.ceildiv(K, group_size) - ], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}" - assert scales_b.shape == [N, T.ceildiv(K, group_size) - ], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}" + assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" + assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}" + assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}" with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), in_dtype) @@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm(): # M, N, K = 1024, 1024, 8192 # A = torch.randn((M, K), dtype=torch.float8_e4m3fn, ) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_let.py b/testing/python/language/test_tilelang_language_let.py index a2af09c..a290595 100644 --- a/testing/python/language/test_tilelang_language_let.py +++ b/testing/python/language/test_tilelang_language_let.py @@ -4,7 +4,6 @@ from tilelang import language as T def test_let_vectorize_load(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) diff --git a/testing/python/language/test_tilelang_language_mask_op.py b/testing/python/language/test_tilelang_language_mask_op.py index ad90785..37b5204 100644 --- a/testing/python/language/test_tilelang_language_mask_op.py +++ b/testing/python/language/test_tilelang_language_mask_op.py @@ -6,11 +6,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): return main -def run_tilelang_copy_mask_parallel_range(M=1024, - N=1024, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_negative_index.py b/testing/python/language/test_tilelang_language_negative_index.py index 4a0df87..c052ccb 100644 --- a/testing/python/language/test_tilelang_language_negative_index.py +++ b/testing/python/language/test_tilelang_language_negative_index.py @@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,) @T.prim_func -def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), - B: T.Buffer((16,), "float32")): +def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): T.func_attr({"tir.noalias": True}) for i in T.serial(16): B[i] = A[shift + i] diff --git a/testing/python/language/test_tilelang_language_parallel.py b/testing/python/language/test_tilelang_language_parallel.py index b51ca8b..b0e85ff 100644 --- a/testing/python/language/test_tilelang_language_parallel.py +++ b/testing/python/language/test_tilelang_language_parallel.py @@ -9,11 +9,10 @@ tilelang.testing.set_random_seed() @tilelang.jit(out_idx=[1]) def parallel_elementwise_static(length=256, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length): @@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"): @tilelang.jit(out_idx=[1]) def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((max_len,), dtype), - B: T.Tensor((max_len,), dtype), - valid_len: T.int32, + A: T.Tensor((max_len,), dtype), + B: T.Tensor((max_len,), dtype), + valid_len: T.int32, ): with T.Kernel(1, threads=threads) as _: for i in T.Parallel(max_len): diff --git a/testing/python/language/test_tilelang_language_pipeline.py b/testing/python/language/test_tilelang_language_pipeline.py index 212f281..54e1055 100644 --- a/testing/python/language/test_tilelang_language_pipeline.py +++ b/testing/python/language/test_tilelang_language_pipeline.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -90,7 +90,8 @@ def run_gemm( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -103,8 +104,8 @@ def run_gemm( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -124,27 +125,19 @@ def test_pipeline_order_stage(): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - dtype="float16", - accum_dtype="float"): - + }, +) +def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"): block_mask_shape = (M // block_M, N // block_N, K // block_K) import tilelang.language as T @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages): a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() - kernel = blocksparse_matmul( - M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages) + kernel = blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages) print(kernel.get_kernel_source()) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) @@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c # Compute the reference result using the naive PyTorch implementation diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index e4659ec..0e60ddd 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( a_ptr: T.ptr, diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index cecfaa0..7ec5003 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((M, N), dtype) @@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1, threads=32) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1, threads=32) as _: A_local = T.alloc_fragment((M, N), dtype) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 60588b4..3c34330 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -10,8 +10,8 @@ def reshape_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_reshaped = T.reshape(A, [N // M, M]) @@ -30,7 +30,8 @@ def run_reshape(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((N,), dtype) @@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((N // M, M), dtype) @@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") @@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") - T.annotate_layout({ - A_shared: make_mma_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_mma_swizzle_layout(A_shared), + } + ) T.copy(A, A_shared) A_shared_reshape = T.reshape(A_shared, [N]) T.copy(A_shared_reshape, B) @@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N,), dtype, scope="shared") @@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_reshaped = T.reshape(A, [N // M, M + 1]) diff --git a/testing/python/language/test_tilelang_language_ternary.py b/testing/python/language/test_tilelang_language_ternary.py index 821231a..632dcf7 100644 --- a/testing/python/language/test_tilelang_language_ternary.py +++ b/testing/python/language/test_tilelang_language_ternary.py @@ -4,19 +4,19 @@ import torch import tilelang.testing -@tilelang.jit(out_idx=[1],) +@tilelang.jit( + out_idx=[1], +) def tilelang_ternary(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = ( - A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0) + B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0 return main diff --git a/testing/python/language/test_tilelang_language_tma_1d.py b/testing/python/language/test_tilelang_language_tma_1d.py index efb665b..90022b5 100644 --- a/testing/python/language/test_tilelang_language_tma_1d.py +++ b/testing/python/language/test_tilelang_language_tma_1d.py @@ -9,10 +9,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py index 1796302..416840a 100644 --- a/testing/python/language/test_tilelang_language_unroll.py +++ b/testing/python/language/test_tilelang_language_unroll.py @@ -4,7 +4,6 @@ from tilelang import language as T def test_unroll_with_step(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) @@ -19,7 +18,6 @@ def test_unroll_with_step(): def test_unroll_with_unroll_factor(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) diff --git a/testing/python/language/test_tilelang_language_var_init.py b/testing/python/language/test_tilelang_language_var_init.py index a5a7dde..d4f9062 100644 --- a/testing/python/language/test_tilelang_language_var_init.py +++ b/testing/python/language/test_tilelang_language_var_init.py @@ -4,17 +4,15 @@ import tilelang.testing def test_var_assign() -> None: - @tilelang.jit(out_idx=-1) def jit_kernel(): - @T.prim_func - def test_var_assign(A: T.Tensor((2,), 'int32')): + def test_var_assign(A: T.Tensor((2,), "int32")): with T.Kernel(1) as _: - a = T.alloc_var('int32', init=1) - b = T.alloc_var('int32', init=a) # b gets value of a + a = T.alloc_var("int32", init=1) + b = T.alloc_var("int32", init=a) # b gets value of a a = 2 - d = T.alloc_var('int32', init=a) # c gets new value of a + d = T.alloc_var("int32", init=a) # c gets new value of a A[0] = b A[1] = d @@ -28,5 +26,5 @@ def test_var_assign() -> None: assert res[1] == 2 -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index bc2d314..6867079 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -5,11 +5,10 @@ import tilelang.language as T @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) def vectorize_test(N, M, stride_A, stride_B): - @T.prim_func def main( - A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 - B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 + A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 + B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 ): with T.Kernel(M // 128, threads=128) as (bx): tx = T.get_thread_binding(0) @@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B): code = jit_kernel.get_kernel_source() vectorize_size = 1 - while vectorize_size <= 2 and \ - stride_A % (vectorize_size * 2) == 0 and \ - stride_B % (vectorize_size * 2) == 0: + while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0: vectorize_size *= 2 if vectorize_size == 4: @@ -61,12 +58,11 @@ def test_vectorize(): @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) def vectorize_test_invariant_index(N, M, K): - @T.prim_func def main( - A: T.Tensor[(N, M), "float32"], # noqa: F821 - B: T.Tensor[(N, M), "float32"], # noqa: F821 - C: T.Tensor[(N, M // K), "float32"], # noqa: F821 + A: T.Tensor[(N, M), "float32"], # noqa: F821 + B: T.Tensor[(N, M), "float32"], # noqa: F821 + C: T.Tensor[(N, M // K), "float32"], # noqa: F821 ): with T.Kernel(N // 128, threads=128) as (bx): tx = T.get_thread_binding(0) diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index afb8a05..adb59a6 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): @T.prim_func def main( - A: T.Tensor[(M,), dtype_A], # noqa: F821 - B: T.Tensor[(M,), dtype_B], # noqa: F821 + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 ): with T.Kernel(1, threads=128): T.copy(A, B) @@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): @T.prim_func def main( - A: T.Tensor[(M,), dtype_A], # noqa: F821 - B: T.Tensor[(M,), dtype_B], # noqa: F821 + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 ): with T.Kernel(1, threads=128): A_local = T.alloc_fragment((M,), dtype_A) @@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, code = kernel.get_kernel_source() code_parallel = kernel_parallel.get_kernel_source() - assert check_str in code and check_str in code_parallel, \ - f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" + assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" def test_vectorized_cast(): diff --git a/testing/python/language/test_tilelang_language_view.py b/testing/python/language/test_tilelang_language_view.py index a79d428..ff050e3 100644 --- a/testing/python/language/test_tilelang_language_view.py +++ b/testing/python/language/test_tilelang_language_view.py @@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None): new_shape = [N // M, M] if new_dtype: from tvm import DataType + dtype_src = DataType(dtype) dtype_dst = DataType(new_dtype) src_bits = dtype_src.bits @@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), ): with T.Kernel(1) as _: A_viewed = T.view(A, new_shape, dtype=new_dtype) @@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None): def ref_program(A): if new_dtype: from tilelang.utils.tensor import map_torch_type + torch_dtype = map_torch_type(new_dtype) return A.view(N // M, M).view(dtype=torch_dtype) return A.view(N // M, M) @@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None): def test_reshape_view(): - # Test view with same dtype run_view(1024, 32, "float32") run_view(2048, 64, "float16") @@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None): new_shape = [N // M, M + 1] if new_dtype: from tvm import DataType + dtype_src = DataType(dtype) dtype_dst = DataType(new_dtype) src_bits = dtype_src.bits @@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), ): with T.Kernel(1) as _: A_viewed = T.view(A, new_shape, dtype=new_dtype) diff --git a/testing/python/language/test_tilelang_language_warp_reduce.py b/testing/python/language/test_tilelang_language_warp_reduce.py index 681b234..0a0fb70 100644 --- a/testing/python/language/test_tilelang_language_warp_reduce.py +++ b/testing/python/language/test_tilelang_language_warp_reduce.py @@ -7,7 +7,6 @@ import tilelang.language as T @tilelang.jit def get_kernel(reduce_op: str, dtype: str): - assert reduce_op in ["sum", "max", "min", "bitand", "bitor"] @T.prim_func @@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str): def test_warp_reduce_sum(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel('sum', 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("sum", "float32") ref = torch.full_like(a, a.sum()) kernel(a) torch.testing.assert_close(a, ref) def test_warp_reduce_max(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel("max", 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("max", "float32") print(kernel.get_kernel_source()) ref = torch.full_like(a, a.max()) kernel(a) @@ -50,16 +49,16 @@ def test_warp_reduce_max(): def test_warp_reduce_min(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel("min", 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("min", "float32") ref = torch.full_like(a, a.min()) kernel(a) torch.testing.assert_close(a, ref) def test_warp_reduce_bitand(): - a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') - kernel = get_kernel("bitand", 'int32') + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitand", "int32") ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val & a[i] @@ -69,8 +68,8 @@ def test_warp_reduce_bitand(): def test_warp_reduce_bitor(): - a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') - kernel = get_kernel("bitor", 'int32') + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitor", "int32") ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val | a[i] diff --git a/testing/python/layout/test_tilelang_layout_fused_replicate.py b/testing/python/layout/test_tilelang_layout_fused_replicate.py index d67a87b..6d3c268 100644 --- a/testing/python/layout/test_tilelang_layout_fused_replicate.py +++ b/testing/python/layout/test_tilelang_layout_fused_replicate.py @@ -12,17 +12,16 @@ VEC_SIZE = 32 @tilelang.jit def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): - @T.prim_func def main( - a: T.Buffer((B, M, N), "bfloat16"), - a_out: T.Buffer((B, M, N), "float32"), + a: T.Buffer((B, M, N), "bfloat16"), + a_out: T.Buffer((B, M, N), "float32"), ): with T.Kernel( - T.ceildiv(M, BLOCK_MN), - T.ceildiv(N, BLOCK_K), - B, - threads=128, + T.ceildiv(M, BLOCK_MN), + T.ceildiv(N, BLOCK_K), + B, + threads=128, ) as (pid_m, pid_n, pid_b): a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32") offs_m = pid_m * BLOCK_MN diff --git a/testing/python/math/test_math_bitwise_reduce.py b/testing/python/math/test_math_bitwise_reduce.py index 9c22946..8d7f5a1 100644 --- a/testing/python/math/test_math_bitwise_reduce.py +++ b/testing/python/math/test_math_bitwise_reduce.py @@ -19,12 +19,11 @@ def bitwise_reduce( func, clear=True, ): - @T.prim_func def reduce_func( - A: T.Tensor((M, N), "int32"), - B: T.Tensor((M), "int32"), - Output: T.Tensor((M), "int32"), + A: T.Tensor((M, N), "int32"), + B: T.Tensor((M), "int32"), + Output: T.Tensor((M), "int32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "int32") @@ -64,7 +63,7 @@ def run_single_bitwise_reduce( row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row # Column-based pattern: different bit positions set based on column - col_pattern = (1 << (j % 31)) # Single bit set at different positions + col_pattern = 1 << (j % 31) # Single bit set at different positions # Combine patterns with XOR to create diverse bit distributions # Add some deterministic "noise" based on position @@ -76,7 +75,7 @@ def run_single_bitwise_reduce( if i % 4 == 0: a[i, j] &= ~(0x1 << (i // 4)) elif i % 2 == 0: - a[i, j] |= (0x1 << (i // 2)) + a[i, j] |= 0x1 << (i // 2) if name == "reduce_bitand": expected = torch.full((M,), -1, device="cuda", dtype=torch.int32) diff --git a/testing/python/math/test_math_fast_math.py b/testing/python/math/test_math_fast_math.py index c3b5d1b..7809983 100644 --- a/testing/python/math/test_math_fast_math.py +++ b/testing/python/math/test_math_fast_math.py @@ -7,16 +7,16 @@ import re def get_mathop_lines(source, mathop_name): """Extract lines containing the mathop from CUDA source for debugging""" - lines = source.split('\n') + lines = source.split("\n") relevant_lines = [] for i, line in enumerate(lines): - if mathop_name in line and ('(' in line): + if mathop_name in line and ("(" in line): # Include some context start = max(0, i - 1) end = min(len(lines), i + 2) relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) relevant_lines.append("---") - return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output def check_fastmath_usage(source, mathop_name, expect_fastmath=False): @@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False): fastmath_matches = re.findall(fastmath_pattern, source) non_fastmath_matches = re.findall(non_fastmath_pattern, source) - print( - f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls" - ) + print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls") if len(fastmath_matches) > 0: print(f"Fastmath calls found: {fastmath_matches}") if len(non_fastmath_matches) > 0: @@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test single-argument mathops. T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name, @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() @@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name, print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j]) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source() @@ -171,8 +159,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -184,7 +172,8 @@ def run_abs_test(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source = kernel.get_kernel_source() print("\n=== Testing abs (maps to fabs) ===") @@ -199,26 +188,19 @@ def run_abs_test(): print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_fastmath = kernel_fastmath.get_kernel_source() print(f"\n=== Testing {mathop_name} (fastmath version) ===") print("FAST_MATH=True:") # Strip the __ prefix for checking in the CUDA source - cuda_mathop_name = mathop_name.lstrip('_') + cuda_mathop_name = mathop_name.lstrip("_") check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness diff --git a/testing/python/math/test_math_ieee_math.py b/testing/python/math/test_math_ieee_math.py index 0b04e3b..193092e 100644 --- a/testing/python/math/test_math_ieee_math.py +++ b/testing/python/math/test_math_ieee_math.py @@ -5,14 +5,7 @@ import tilelang.testing import pytest -def run_ieee_math_test(mathop_name, - mathop_func, - rounding_mode="rn", - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test IEEE-compliant math operations with specified rounding modes. """ @@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), - D: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + D: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - D[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j], - C[by * block_M + i, - bx * block_N + j], rounding_mode) + D[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], + B[by * block_M + i, bx * block_N + j], + C[by * block_M + i, bx * block_N + j], + rounding_mode, + ) out_idx = [3] num_inputs = 3 @@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, - bx * block_N + j], rounding_mode) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode + ) out_idx = [2] num_inputs = 2 @@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - rounding_mode) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode) out_idx = [1] num_inputs = 1 @@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===") print(f"✓ {mathop_name} compilation test passed") @@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only(): @T.prim_func def main( - A: T.Tensor((128, 128), "float32"), - B: T.Tensor((128, 128), "float32"), + A: T.Tensor((128, 128), "float32"), + B: T.Tensor((128, 128), "float32"), ): with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by): for i, j in T.Parallel(32, 32): @@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) print("\n=== Testing ieee_frsqrt (rn only) ===") print("✓ ieee_frsqrt compilation test passed") diff --git a/testing/python/metal/test_metal_codegen.py b/testing/python/metal/test_metal_codegen.py index 22f4beb..ea088ae 100644 --- a/testing/python/metal/test_metal_codegen.py +++ b/testing/python/metal/test_metal_codegen.py @@ -5,18 +5,17 @@ import tilelang.language as T import torch -@tilelang.jit(execution_backend='torch') +@tilelang.jit(execution_backend="torch") def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared') - B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared') + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -48,13 +47,13 @@ def assert_gemm( torch_dtype = getattr(torch, dtype) a, b = None, None - if 'int' in dtype: - a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps') - b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps') + if "int" in dtype: + a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps") + b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps") else: - a = torch.randn(M, K, dtype=torch_dtype, device='mps') - b = torch.randn(K, N, dtype=torch_dtype, device='mps') - c = torch.zeros(M, N, dtype=torch_dtype, device='mps') + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_dtype, device="mps") jit_kernel(a, b, c) @@ -70,12 +69,12 @@ def test_gemm_float32(): @tilelang.testing.requires_metal def test_gemm_float16(): - assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1) + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1) @tilelang.testing.requires_metal def test_gemm_int32(): - assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1) + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1) if __name__ == "__main__": diff --git a/testing/python/primitives/test_tilelang_primitives_mma.py b/testing/python/primitives/test_tilelang_primitives_mma.py index fcda987..97ce323 100644 --- a/testing/python/primitives/test_tilelang_primitives_mma.py +++ b/testing/python/primitives/test_tilelang_primitives_mma.py @@ -27,9 +27,9 @@ def matmul_ssr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) @@ -88,7 +88,8 @@ def run_matmul_ssr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -106,24 +107,9 @@ def run_matmul_ssr( def test_gemm_f16f16f16_nt_ssr(): - run_matmul_ssr( - 16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32) - run_matmul_ssr( - 128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64) - run_matmul_ssr( - 1024, - 1024, - 1024, - False, - True, - "float16", - "float16", - "float16", - 128, - 128, - 32, - 2, - num_threads=128) + run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32) + run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64) + run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128) def matmul_rsr( @@ -151,9 +137,9 @@ def matmul_rsr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) @@ -214,7 +200,8 @@ def run_matmul_rsr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -276,9 +263,9 @@ def matmul_rrr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -342,7 +329,8 @@ def run_matmul_rrr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): diff --git a/testing/python/profiler/test_tilelang_profiler.py b/testing/python/profiler/test_tilelang_profiler.py index ee46725..8aa5470 100644 --- a/testing/python/profiler/test_tilelang_profiler.py +++ b/testing/python/profiler/test_tilelang_profiler.py @@ -4,12 +4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index d984ad4..a13e453 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -89,7 +89,8 @@ def run_gemm_ss( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) @@ -159,9 +160,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -169,9 +170,11 @@ def matmul_rs( A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) @@ -225,7 +228,8 @@ def run_gemm_rs( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): @@ -294,9 +298,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -304,9 +308,11 @@ def matmul_sr( B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) @@ -360,7 +366,8 @@ def run_gemm_sr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): @@ -430,9 +437,9 @@ def matmul_rr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -441,10 +448,12 @@ def matmul_rr( B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) @@ -499,7 +508,8 @@ def run_gemm_rr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index cefe986..4ced4f8 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse( - M, - K, - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda', - transposed=trans_A) - B = torch.randint( - size=(N, K) if trans_B else (K, N), - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda') + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") else: - A = randn_semi_sparse( - M, K, dtype=torch.float32, device='cuda', - transposed=trans_A).to(map_torch_type(in_dtype)) - B = torch.randn( - (N, K) if trans_B else (K, N), device='cuda', - dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype)) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) return A, B @@ -69,24 +53,22 @@ def matmul_sp_sm90( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), 'uint8'), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8') + E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8") C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + } + ) T.disable_warp_group_reg_alloc() T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -121,7 +103,7 @@ def matmul_sp_sm80( trans_B, ): is_8_bit = "8" in in_dtype - metadata_dtype = 'int32' if is_8_bit else 'int16' + metadata_dtype = "int32" if is_8_bit else "int16" E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) B_shape = (K, N) if not trans_B else (N, K) @@ -132,20 +114,22 @@ def matmul_sp_sm80( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -216,7 +200,7 @@ def run_gemm_sp( C = _matmul(A, B) - if 'float8' in in_dtype: + if "float8" in in_dtype: diff = calc_diff(C_sp, C) assert diff < 1e-3, f"{diff=}" else: @@ -332,15 +316,11 @@ def test_gemm_sp_sm90(): run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, - True) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, - False) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, - True) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True) - run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, - True) + run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True) run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True) @@ -352,12 +332,9 @@ def test_gemm_sp_sm80(): run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, - True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, - True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, - True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index a82c29f..276bce4 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -34,20 +34,22 @@ def matmul( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -80,7 +82,7 @@ def run_gemm_ss( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul( M, N, @@ -105,7 +107,8 @@ def run_gemm_ss( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") @@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse( - M, - K, - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda', - transposed=trans_A) - B = torch.randint( - size=(N, K) if trans_B else (K, N), - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda') + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") else: - A = randn_semi_sparse( - M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A) - B = torch.randn( - (N, K) if trans_B else (K, N), device='cuda', - dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) return A, B @@ -184,8 +172,7 @@ def test_gemm_ss(): run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) # float8 tests - run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, - 2) + run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) # tfloat32 test @@ -222,10 +209,10 @@ def matmul_rs( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -233,11 +220,13 @@ def matmul_rs( E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -271,7 +260,7 @@ def run_gemm_rs( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul_rs( M, N, @@ -296,7 +285,8 @@ def run_gemm_rs( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) @@ -376,10 +366,10 @@ def matmul_sr( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -387,11 +377,13 @@ def matmul_sr( E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -425,7 +417,7 @@ def run_gemm_sr( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul_sr( M, N, @@ -450,7 +442,8 @@ def run_gemm_sr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) @@ -531,10 +524,10 @@ def matmul_rr( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -543,12 +536,14 @@ def matmul_rr( A_frag = T.alloc_fragment(A_frag_shape, in_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -583,7 +578,7 @@ def run_gemm_rr( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul_rr( M, N, @@ -608,7 +603,8 @@ def run_gemm_rr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index 7cb1b55..d3f45c5 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -11,22 +11,14 @@ def _check(original, transformed): mod = tl.transform.Simplify()(mod) mod = tl.transform.LowerOpaqueBlock()(mod) mod = tl.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), - True) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) def test_trival_pipeline(): - @T.prim_func def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial( - 0, - 1, - annotations={ - "software_pipeline_stage": [0, 1], - "software_pipeline_order": [0, 1] - }): + for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}): with T.block(): T.reads(A[tx, i]) T.writes(C[tx, i]) diff --git a/testing/python/transform/test_tilelang_transform_cluster_planning.py b/testing/python/transform/test_tilelang_transform_cluster_planning.py index 8029305..2ec6321 100644 --- a/testing/python/transform/test_tilelang_transform_cluster_planning.py +++ b/testing/python/transform/test_tilelang_transform_cluster_planning.py @@ -21,10 +21,8 @@ def _check(original, transformed): def test_cluster_planning(): - @T.prim_func - def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor( - (1024, 1024), "float16")): + def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float16") B_shared = T.alloc_shared((32, 128), "float16") @@ -41,8 +39,7 @@ def test_cluster_planning(): T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor( - (1024, 1024), "float16")): + def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): T.func_attr({"clusterIdx.y": T.int32(2)}) with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float16") diff --git a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py index 1ef1589..339b283 100644 --- a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py +++ b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py @@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 0 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) batch = T.int32(batch) heads = T.int32(heads) @@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.Tensor([block_M, dim], dtype), - acc_s_cast: T.Tensor([block_M, block_N], dtype), - acc_o: T.Tensor([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, + V: T.Tensor(shape, dtype), + V_shared: T.Tensor([block_M, dim], dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + acc_o: T.Tensor([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.Tensor([block_M, block_N], accum_dtype), - acc_s_cast: T.Tensor([block_M, block_N], dtype), - scores_max: T.Tensor([block_M], accum_dtype), - scores_max_prev: T.Tensor([block_M], accum_dtype), - scores_scale: T.Tensor([block_M], accum_dtype), - scores_sum: T.Tensor([block_M], accum_dtype), - logsum: T.Tensor([block_M], accum_dtype), + acc_s: T.Tensor([block_M, block_N], accum_dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + scores_max: T.Tensor([block_M], accum_dtype), + scores_max_prev: T.Tensor([block_M], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), + scores_sum: T.Tensor([block_M], accum_dtype), + logsum: T.Tensor([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.Tensor([block_M, dim], accum_dtype), - scores_scale: T.Tensor([block_M], accum_dtype), + acc_o: T.Tensor([block_M, dim], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py index 2859821..854a261 100644 --- a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -22,7 +22,6 @@ def _check(original, transformed): def test_lower_fence_proxy(): - @T.prim_func def before(): with T.Kernel(8): @@ -30,12 +29,15 @@ def test_lower_fence_proxy(): B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): - C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) - T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"), - "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(): @@ -44,19 +46,21 @@ def test_lower_fence_proxy(): B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): - C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) T.fence_proxy_async() - T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"), - "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) _check(before, after) def test_async_to_generic_no_double_fence(): - @T.prim_func def before(): with T.Kernel(8): @@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence(): def test_proxy_hint_override(): - @T.prim_func def before(): with T.Kernel(8): @@ -123,7 +126,6 @@ def test_proxy_hint_override(): def test_tma_store_sync_injection(): - @T.prim_func def before(): with T.Kernel(8): @@ -154,7 +156,6 @@ def test_tma_store_sync_injection(): def test_wgmma_marked_async(): - @T.prim_func def before(): with T.Kernel(1): @@ -164,9 +165,24 @@ def test_wgmma_marked_async(): C_local = T.decl_buffer((32,), "float16", scope="local") A_shared[0] = T.float16(0) T.warpgroup_arrive() - T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16", - "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data, - T.int32(0), T.bool(True), 1, 1) + T.ptx_wgmma_ss( + "float16", + "m64n64k16", + T.bool(True), + T.bool(True), + "fp16", + "fp16", + "fp16", + desc_a.data, + T.int32(0), + desc_b.data, + T.int32(0), + C_local.data, + T.int32(0), + T.bool(True), + 1, + 1, + ) mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) diff --git a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py index 95cbf2d..0cc79b9 100644 --- a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py +++ b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py @@ -35,26 +35,25 @@ def test_inject_set_max_nreg(): T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1)) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, - 0, 2, 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) else: # Consumer branch - should have set_max_nreg(240, 1) for k in range(16): T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) # Apply the InjectSetMaxNReg pass func = before @@ -67,15 +66,18 @@ def test_inject_set_max_nreg(): set_max_nreg_calls = [] def collect_set_max_nreg(stmt): - if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and - hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): set_max_nreg_calls.append(stmt.value) tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) # We should have at least 2 set_max_nreg calls (one for producer, one for consumer) - assert len(set_max_nreg_calls - ) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}" + assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}" print("InjectSetMaxNReg test passed!") @@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg(): set_max_nreg_calls = [] def collect_set_max_nreg(stmt): - if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and - hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): set_max_nreg_calls.append(stmt.value) tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) # Should have no set_max_nreg calls when no_set_max_nreg is present - assert len( - set_max_nreg_calls - ) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}" + assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}" print("InjectSetMaxNReg with no_set_max_nreg test passed!") diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index 66415aa..270dd31 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -8,17 +8,21 @@ import pytest auto_target = tvm.target.Target(determine_target("auto")) -@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), -]) +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, "float16"), + ], +) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") def before(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): for vec in T.Parallel(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) + vec], - T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) def after(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): - if (k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b)) * N % vec_load_b == 0: + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: for vec in T.vectorized(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) else: for vec in T.serial(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) with tvm.target.Target(auto_target): mod = tvm.tir.transform.BindTarget(auto_target)(before()) diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index 5202ab6..35a85aa 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() @@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off A_shared[tid, j] = A[tid + M_offset, j + N_offset] @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() - T.reads(A[tid + M_offset, N_offset:N + N_offset]) + T.reads(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): A_shared[tid, j] = T.if_then_else( - j + N_offset < N, - T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], - T.float32(0)), T.float32(0)) + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) return main, expected @@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64): def issue_1013_buggy_kernel(): # NOTE: This kernel is mainly to test some corner cases in boundary check - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") num_threads = 128 @T.prim_func def main(x: T.Tensor((num_tokens,), dtype="int64")): with T.Kernel(1, threads=num_threads) as _: - count = T.alloc_var('int') + count = T.alloc_var("int") thread_idx = T.get_thread_binding() for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): idx = thread_idx + i * num_threads @@ -59,24 +62,22 @@ def issue_1013_buggy_kernel(): @T.prim_func def expected(x: T.Tensor((num_tokens,), dtype="int64")): with T.Kernel(1, threads=num_threads) as _: - count = T.alloc_var('int') + count = T.alloc_var("int") thread_idx = T.get_thread_binding() for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): idx = thread_idx + i * num_threads - count += T.Cast("int32", - T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2)) + count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2)) return main, expected -def vectorize_access_with_atmoic_add_legalize(M: int = 64, - N: int = 64, - M_offset: int = 2, - N_offset: int = 2): +def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() @@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64, T.atomic_add(A[tid + M_offset, j + N_offset], 1) @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() - T.reads(A[tid + M_offset, N_offset:N + N_offset]) + T.reads(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): A_shared[tid, j] = T.if_then_else( - j + N_offset < N, - T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], - T.float32(0)), T.float32(0)) + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) # Nest if-then-else is expected, do not flatten it to pass structural equal check if j + N_offset < N: # noqa: SIM102 if tid + M_offset < M: @@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): tid = T.get_thread_binding() for j in T.serial(N): A[tid + M_offset, j + N_offset] = 1 @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): tid = T.get_thread_binding() - T.writes(A[tid + M_offset, N_offset:N + N_offset]) + T.writes(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): if j + N_offset < N: # noqa: SIM102 if tid + M_offset < M: diff --git a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py index c95af87..ec570d4 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py +++ b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py @@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): vec_len = 8 @T.prim_func - def main(A: T.Tensor((M, N, vec_len), dtype="float32"),): + def main( + A: T.Tensor((M, N, vec_len), dtype="float32"), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) tid = T.get_thread_binding() @@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): A_shared[tid, j, v] = A[tid, j, v] @T.prim_func - def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),): + def expected( + A: T.Tensor((M, N, vec_len), dtype="float32"), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) tid = T.get_thread_binding() diff --git a/testing/python/transform/test_tilelang_transform_let_inline.py b/testing/python/transform/test_tilelang_transform_let_inline.py index aa2638a..6603eca 100644 --- a/testing/python/transform/test_tilelang_transform_let_inline.py +++ b/testing/python/transform/test_tilelang_transform_let_inline.py @@ -8,12 +8,10 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tl.transform.LetInline()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), - True) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) def test_let_binding(): - @T.prim_func def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): for i in range(128): @@ -34,7 +32,6 @@ def test_let_binding(): def test_parallel_scope(): - @T.prim_func def before(A: T.Tensor((128,), "float32")): for i in T.Parallel(128): diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index ca5042e..f411b3d 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -24,7 +24,6 @@ def _check(original, transformed): def test_lower_hopper_intrin_barrier(): - @T.prim_func def before(): with T.Kernel(8): @@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier(): v_1 = T.launch_thread("threadIdx.x", 128) T.evaluate(tir.Call("handle", "tir.create_barriers", [4])) with T.If(v_1 == 0), T.Then(): - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(0), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(1), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(2), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(3), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128])) T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"])) _check(before, after) diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index 07dbd53..ac58418 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -8,63 +8,69 @@ import pytest auto_target = tvm.target.Target(determine_target("auto")) -@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), -]) +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, "float16"), + ], +) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") def before(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(B[k * block_K, bx * block_N], B_shared) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) def after(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): - if (k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b)) * N % vec_load_b == 0: + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: for vec in T.vectorized(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) else: for vec in T.serial(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) with tvm.transform.PassContext(): mod = tvm.tir.transform.BindTarget(auto_target)(before()) diff --git a/testing/python/transform/test_tilelang_transform_make_packed_api.py b/testing/python/transform/test_tilelang_transform_make_packed_api.py index ff44873..2508a9d 100644 --- a/testing/python/transform/test_tilelang_transform_make_packed_api.py +++ b/testing/python/transform/test_tilelang_transform_make_packed_api.py @@ -80,7 +80,6 @@ def test_target_host_removed(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) @@ -102,7 +101,6 @@ def test_internal_subroutine_call(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("llvm", host="llvm")}) @@ -121,7 +119,8 @@ def test_internal_subroutine_call(): subroutine_call_op = compute_scope.body.value.op assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), ( f"The main function's CallNode should use the subroutine's GLobalVar as the operation, " - f"but instead has an operation of type {subroutine_call_op}") + f"but instead has an operation of type {subroutine_call_op}" + ) def test_subroutine_call_to_externally_visible_subroutine(): @@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) @@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine(): assert subroutine_compute_scope is not None subroutine_call_op = main_compute_scope.body.value.op - assert ( - isinstance(subroutine_call_op, tvm.ir.Op) and - subroutine_call_op.name == "tir.tvm_call_cpacked" - ), (f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " - f"but instead has an operation of type {subroutine_call_op}") + assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", ( + f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " + f"but instead has an operation of type {subroutine_call_op}" + ) @tilelang.testing.requires_llvm @@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count(): @T.prim_func def func( - A: T.Buffer([16, 16], "int32"), - B: T.Buffer([16, 16], "int32"), - C: T.Buffer([16, 16], "int32"), - D: T.Buffer([16, 16], "int32"), + A: T.Buffer([16, 16], "int32"), + B: T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + D: T.Buffer([16, 16], "int32"), ): pass diff --git a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py index ddb7f66..0d56ab1 100644 --- a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py +++ b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py @@ -31,7 +31,6 @@ block_K = 32 def test_multi_version_buffer(): - @T.prim_func def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): bx = T.launch_thread("blockIdx.x", 8) @@ -49,21 +48,27 @@ def test_multi_version_buffer(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), - k * 32, by * 64) + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), - bx * 64, k * 32) + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): @@ -82,31 +87,32 @@ def test_multi_version_buffer(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) _check(before, after) def test_multi_version_buffer_with_let(): - @T.prim_func def before(scales: T.Tensor((4,), "float32")): with T.block("root"): diff --git a/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/testing/python/transform/test_tilelang_transform_pipeline_planning.py index b7448a2..f38d607 100644 --- a/testing/python/transform/test_tilelang_transform_pipeline_planning.py +++ b/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -19,10 +19,8 @@ def _check(original, transformed): def test_simple_pipeline(): - @T.prim_func - def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor( - (1024, 1024), "float32")): + def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float32") B_shared = T.alloc_shared((32, 128), "float32") @@ -39,8 +37,7 @@ def test_simple_pipeline(): T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor( - (1024, 1024), "float32")): + def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float32") B_shared = T.alloc_shared((32, 128), "float32") @@ -49,14 +46,13 @@ def test_simple_pipeline(): T.clear(C_local) for ko in T.serial( - 32, - annotations={ - "software_pipeline_async_stages": [T.int32(0)], - "software_pipeline_order": [T.int32(0), T.int32(1), - T.int32(2)], - "software_pipeline_stage": [T.int32(3), T.int32(3), - T.int32(3)] - }): + 32, + annotations={ + "software_pipeline_async_stages": [T.int32(0)], + "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)], + "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)], + }, + ): T.copy(A[by * 128, ko * 32], A_shared) T.copy(B[ko * 32, bx * 128], B_shared) T.gemm(A_shared, B_shared, C_local) diff --git a/testing/python/transform/test_tilelang_transform_simplify.py b/testing/python/transform/test_tilelang_transform_simplify.py index e1f4f94..657a2e8 100644 --- a/testing/python/transform/test_tilelang_transform_simplify.py +++ b/testing/python/transform/test_tilelang_transform_simplify.py @@ -8,14 +8,13 @@ def modify( with_B: bool = False, with_bias: bool = False, ): - @T.prim_func def main( - A: T.Tensor((64, 64)), - B: T.Tensor((64, 64)), - C: T.Tensor((64, 64)), - D: T.Tensor((64, 64)), - bias: T.Tensor((64, 64)), + A: T.Tensor((64, 64)), + B: T.Tensor((64, 64)), + C: T.Tensor((64, 64)), + D: T.Tensor((64, 64)), + bias: T.Tensor((64, 64)), ): if with_B: if with_bias: @@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( a: T.handle, @@ -76,6 +74,7 @@ def test_matmul(): kernel = tl.compile(mod["main"], out_idx=[2]) import torch + a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() c = kernel(a, b) diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index c0b7055..046ed44 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc): cuda_target = tvm.target.Target("cuda", host="llvm") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr({ - "global_symbol": "test", - "target": cuda_target - }))( - mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod) mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) mod = tvm.tir.transform.SplitHostDevice()(mod) @@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc): @tilelang.testing.requires_cuda def test_sync_if_with_same_index(): - @T.prim_func(check_well_formed=False) def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: threadIdx_x = T.env_thread("threadIdx.x") @@ -47,7 +42,6 @@ def test_sync_if_with_same_index(): @tilelang.testing.requires_cuda def test_sync_read_thread_id_independent_location(): - @T.prim_func def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: threadIdx_x = T.env_thread("threadIdx.x") @@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location(): @tilelang.testing.requires_cuda def test_sync_shared(): - @T.prim_func(private=True) def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 1) @@ -113,7 +106,6 @@ def test_sync_shared(): @tvm.testing.requires_cuda def test_sync_let_stmt(): - @T.prim_func(private=True) def func(A: T.Buffer((16 * 512), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 16) @@ -136,9 +128,9 @@ def test_sync_let_stmt(): in_thread_A_temp_1[0] = A_temp cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local") with T.attr( - T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), - "reduce_scope", - T.reinterpret("handle", T.uint64(0)), + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), ): T.tvm_thread_allreduce( T.uint32(1), @@ -190,16 +182,19 @@ def test_sync_let_stmt(): @tilelang.testing.requires_cuda def test_sync_shared_dyn_stmatrix_loop_hoist(): - @T.prim_func def func(): buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn") tx = T.launch_thread("threadIdx.x", 384) for i in T.unroll(8): off = ( - i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 + - (tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 + - (tx % 32 // 16 + tx % 2) % 2 * 8) + i // 4 * 8192 + + tx // 32 * 1024 + + tx % 16 * 64 + + (tx % 8 // 4 + i % 4 // 2) % 2 * 32 + + (tx % 4 // 2 + i % 2) % 2 * 16 + + (tx % 32 // 16 + tx % 2) % 2 * 8 + ) T.evaluate( T.call_intrin( "handle", @@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist(): 2, ), T.int32(2), - )) + ) + ) mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) diff --git a/testing/python/transform/test_tilelang_transform_warp_specialized.py b/testing/python/transform/test_tilelang_transform_warp_specialized.py index 063ae29..2e101bf 100644 --- a/testing/python/transform/test_tilelang_transform_warp_specialized.py +++ b/testing/python/transform/test_tilelang_transform_warp_specialized.py @@ -32,7 +32,6 @@ block_K = 32 def test_warp_specialized(): - @T.prim_func def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): bx = T.launch_thread("blockIdx.x", 8) @@ -47,25 +46,27 @@ def test_warp_specialized(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): @@ -85,34 +86,35 @@ def test_warp_specialized(): T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v - 128 == 0: T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) else: T.set_max_nreg(240, 1) for k in range(16): T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) _check(before, after) diff --git a/testing/python/utils/test_compress_utils.py b/testing/python/utils/test_compress_utils.py index 1ec4cac..e8fc205 100644 --- a/testing/python/utils/test_compress_utils.py +++ b/testing/python/utils/test_compress_utils.py @@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse def _test_compress_sm90(M, K, block_k, dtype): - A = randn_semi_sparse(M, K, dtype=dtype, device='cuda') + A = randn_semi_sparse(M, K, dtype=dtype, device="cuda") A_sparse, E = compress_sm90(A, block_k, False) diff --git a/testing/python/webgpu/test_webgpu_codegen.py b/testing/python/webgpu/test_webgpu_codegen.py index 0fe4f19..ed17527 100644 --- a/testing/python/webgpu/test_webgpu_codegen.py +++ b/testing/python/webgpu/test_webgpu_codegen.py @@ -5,12 +5,11 @@ import tilelang.language as T def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 0d8c21b..1f2a4f4 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -23,6 +23,7 @@ def _compute_version() -> str: if version_file.is_file(): try: from version_provider import dynamic_metadata # type: ignore + return dynamic_metadata("version") except Exception: # Fall back to the raw VERSION file if provider isn't available. @@ -33,6 +34,7 @@ def _compute_version() -> str: try: from importlib.metadata import version as _dist_version # py3.8+ + return _dist_version("tilelang") except Exception as exc: warnings.warn( diff --git a/tilelang/analysis/fragment_loop_checker.py b/tilelang/analysis/fragment_loop_checker.py index 3186b23..94900a5 100644 --- a/tilelang/analysis/fragment_loop_checker.py +++ b/tilelang/analysis/fragment_loop_checker.py @@ -1,6 +1,6 @@ from __future__ import annotations from tvm import tir -from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm) +from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm from tvm.tir.transform import prim_func_pass from tvm.tir.stmt_functor import post_order_visit @@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor): def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: """ - Collect local buffer accesses in the loop body. + Collect local buffer accesses in the loop body. - Args: - statement: The TIR statement to analyze + Args: + statement: The TIR statement to analyze - Returns: - Tuple of buffer accesses in the loop body. - """ + Returns: + Tuple of buffer accesses in the loop body. + """ buffer_accesses = [] @@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: @tir.functor.visitor class _FragmentLoopCheckVisitor(PyStmtExprVisitor): - def __init__(self) -> None: super().__init__() @@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor): raise ValueError( "[Tilelang Semantic Check] " f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index " - "a local/fragment buffer, which is not allowed in Tilelang.") + "a local/fragment buffer, which is not allowed in Tilelang." + ) return diff --git a/tilelang/analysis/layout_visual.py b/tilelang/analysis/layout_visual.py index 782b912..141fb80 100644 --- a/tilelang/analysis/layout_visual.py +++ b/tilelang/analysis/layout_visual.py @@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str: if isinstance(layout, T.Fragment): input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() - lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", - f" Index: {layout.forward_index}" - ] + lines = [f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}"] print("\n".join(lines)) else: raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") @@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor): def LayoutVisual(formats: str = ""): - def pass_fn(func: tir.PrimFunc, mod, ctx): _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py index eff0fc2..51da7f4 100644 --- a/tilelang/analysis/nested_loop_checker.py +++ b/tilelang/analysis/nested_loop_checker.py @@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass def is_pipelined_for(op: For) -> bool: """Check if a for loop is pipelined.""" - anno_keys = [ - "num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", - "tl_pipeline_group" - ] + anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"] return any(key in op.annotations for key in anno_keys) @@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool: @tir.functor.visitor class _NestedLoopCheckVisitor(PyStmtExprVisitor): - def __init__(self) -> None: super().__init__() self.in_parallel_context = False @@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): # Otherwise if self.in_parallel_context: - raise ValueError("[Tilelang Semantic Check] " - "Nested parallel loops are not allowed. " - "Please check your loop structure.") + raise ValueError("[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure.") self.in_parallel_context = True super().visit_for_(op) self.in_parallel_context = False return elif is_pipelined_for(op): if self.in_parallel_context: - raise ValueError("[Tilelang Semantic Check] " - "Pipelined loop cannot be nested inside a parallel loop. " - "Please check your loop structure.") + raise ValueError( + "[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure." + ) super().visit_for_(op) def visit_call_(self, op: Call) -> None: if self.in_parallel_context and is_tile_op(op): - raise ValueError("[Tilelang Semantic Check] " - "Only elementwise operations are allowed inside a parallel loop. " \ - f"Got a tile-op \"{op.op}\"." - ) + raise ValueError( + f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".' + ) def NestedLoopChecker(): diff --git a/tilelang/autotuner/capture.py b/tilelang/autotuner/capture.py index 27c24f1..428a6da 100644 --- a/tilelang/autotuner/capture.py +++ b/tilelang/autotuner/capture.py @@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack: class AutotuneInputsCapture: - - __slots__ = ("tensors") + __slots__ = "tensors" def __init__(self, tensors: list[Any]): self.tensors = tensors diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 4c8d9a9..69ad49c 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -1,5 +1,5 @@ -"""The auto-tune parameters. -""" +"""The auto-tune parameters.""" + from __future__ import annotations import tilelang @@ -50,7 +50,7 @@ class CompileArgs: out_idx: list[int] | int | None = None execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto" - target: Literal['auto', 'cuda', 'hip'] = 'auto' + target: Literal["auto", "cuda", "hip"] = "auto" target_host: str | Target = None verbose: bool = False pass_configs: dict[str, Any] | None = None @@ -62,24 +62,20 @@ class CompileArgs: target=self.target, target_host=self.target_host, verbose=self.verbose, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) def __hash__(self): data = { - "execution_backend": - self.execution_backend, - "target": - str(self.target), - "target_host": - str(self.target_host) if self.target_host else None, - "verbose": - self.verbose, - "pass_configs": - json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None, + "execution_backend": self.execution_backend, + "target": str(self.target), + "target_host": str(self.target_host) if self.target_host else None, + "verbose": self.verbose, + "pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None, } - hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) - return int.from_bytes(hash_obj.digest(), byteorder='big') + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") @dataclass(frozen=True) @@ -104,6 +100,7 @@ class ProfileArgs: manual_check_prog: Callable = None cache_input_tensors: bool = True """ + warmup: int = 25 rep: int = 100 timeout: int = 30 @@ -127,8 +124,8 @@ class ProfileArgs: "atol": self.atol, "max_mismatched_ratio": self.max_mismatched_ratio, } - hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) - return int.from_bytes(hash_obj.digest(), byteorder='big') + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") @dataclass(frozen=True) @@ -143,6 +140,7 @@ class AutotuneResult: func: Optimized function. kernel: Compiled kernel function. """ + latency: float | None = None config: dict | None = None ref_latency: float | None = None @@ -199,8 +197,7 @@ class AutotuneResult: if verbose: logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - self._safe_write_file(device_kernel_path, "w", - lambda f: f.write(kernel.kernel_source)) + self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source)) except Exception as e: logger.error(f"Error saving kernel source code to disk: {e}") @@ -211,11 +208,9 @@ class AutotuneResult: logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel if kernel.execution_backend == "tvm_ffi": - self._safe_write_file(host_kernel_path, "w", - lambda f: f.write(kernel.adapter.get_host_source())) + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source())) else: - self._safe_write_file(host_kernel_path, "w", - lambda f: f.write(kernel.adapter.get_kernel_source())) + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source())) except Exception as e: logger.error(f"Error saving wrapped kernel source code to disk: {e}") @@ -237,12 +232,10 @@ class AutotuneResult: py_src_path = src_lib_path.replace(".cubin", ".py") if verbose: logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - self._safe_write_file(kernel_py_path, "wb", - lambda f: f.write(self._load_binary(py_src_path))) + self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path))) if verbose: logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", - lambda f: f.write(self._load_binary(src_lib_path))) + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) elif kernel.execution_backend == "tvm_ffi": executable = kernel.adapter.executable if verbose: @@ -252,8 +245,7 @@ class AutotuneResult: src_lib_path = kernel.adapter.libpath if verbose: logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", - lambda f: f.write(self._load_binary(src_lib_path))) + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) except Exception as e: logger.error(f"Error saving kernel library to disk: {e}") @@ -370,14 +362,12 @@ class AutotuneResult: # save best config (atomic) if verbose: logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") - self._safe_write_file( - str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) + self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) # save function (atomic) if verbose: logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") - self._safe_write_file( - str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) + self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) # save ref latency (atomic) if verbose: @@ -385,10 +375,13 @@ class AutotuneResult: self._safe_write_file( str(path / LATENCY_PATH), "w", - lambda f: json.dump({ - "latency": self.latency, - "ref_latency": self.ref_latency, - }, f), + lambda f: json.dump( + { + "latency": self.latency, + "ref_latency": self.ref_latency, + }, + f, + ), ) # save kernel @@ -403,8 +396,8 @@ class AutotuneResult: # Normalize target and resolve execution backend for loading from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend - norm_target = Target(_determine_target(compile_args.target)) if isinstance( - compile_args.target, str) else compile_args.target + + norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target requested_backend = compile_args.execution_backend resolved_backend = resolve_execution_backend(requested_backend, norm_target) # load best config diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 9b2fca2..5bbdc48 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -3,6 +3,7 @@ This module provides functionality for auto-tuning tilelang programs, including JIT compilation and performance optimization through configuration search. """ + from __future__ import annotations from dataclasses import dataclass @@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var from tvm.target import Target import inspect from functools import partial -from typing import (Callable, Generic, Literal, Any, TypeVar) +from typing import Callable, Generic, Literal, Any, TypeVar + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec @@ -74,8 +76,8 @@ def _init_logger_handlers(): global _logger_handlers_initialized if _logger_handlers_initialized: return - formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') - file_handler = logging.FileHandler('autotuner.log', mode='w') + formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s") + file_handler = logging.FileHandler("autotuner.log", mode="w") file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) console_handler = logging.StreamHandler(sys.stdout) @@ -87,8 +89,7 @@ def _init_logger_handlers(): def get_available_cpu_count() -> int: - """Gets the number of CPU cores available to the current process. - """ + """Gets the number of CPU cores available to the current process.""" try: cpu_count = len(os.sched_getaffinity(0)) except AttributeError: @@ -107,6 +108,7 @@ class AutoTuner: fn: The function to be auto-tuned. configs: List of configurations to try during auto-tuning. """ + compile_args = CompileArgs() profile_args = ProfileArgs() @@ -137,14 +139,15 @@ class AutoTuner: """ return cls(kernel, configs) - def set_compile_args(self, - out_idx: list[int] | int | None = None, - target: Literal['auto', 'cuda', 'hip', 'metal'] = 'auto', - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", - target_host: str | Target = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None): + def set_compile_args( + self, + out_idx: list[int] | int | None = None, + target: Literal["auto", "cuda", "hip", "metal"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + target_host: str | Target = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + ): """Set compilation arguments for the auto-tuner. Args: @@ -161,6 +164,7 @@ class AutoTuner: # Normalize target to a concrete TVM Target and resolve execution backend t = Target(determine_target(target)) from tilelang.jit.execution_backend import resolve_execution_backend + resolved_backend = resolve_execution_backend(execution_backend, t) self.compile_args = CompileArgs( @@ -169,23 +173,26 @@ class AutoTuner: execution_backend=resolved_backend, target_host=target_host, verbose=verbose, - pass_configs=pass_configs) + pass_configs=pass_configs, + ) return self - def set_profile_args(self, - warmup: int = 25, - rep: int = 100, - timeout: int = 30, - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False): + def set_profile_args( + self, + warmup: int = 25, + rep: int = 100, + timeout: int = 30, + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, + ref_prog: Callable = None, + supply_prog: Callable = None, + rtol: float = 1e-2, + atol: float = 1e-2, + max_mismatched_ratio: float = 0.01, + skip_check: bool = False, + manual_check_prog: Callable = None, + cache_input_tensors: bool = False, + ): """Set profiling arguments for the auto-tuner. Args: @@ -209,9 +216,7 @@ class AutoTuner: # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead. if get_autotune_inputs() is not None: if supply_prog is not None: - logger.warning( - "`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context." - ) + logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.") supply_prog = lambda _: get_autotune_inputs() # noqa: E731 self.profile_args = ProfileArgs( @@ -226,13 +231,13 @@ class AutoTuner: cache_input_tensors=cache_input_tensors, warmup=warmup, rep=rep, - timeout=timeout) + timeout=timeout, + ) # If a custom `supply_prog` is provided, the profiler's `supply_type` setting # becomes ineffective. The custom supply program will be used instead. if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto: - logger.warning("Ignoring `supply_type` passed to `set_profile_args` because " - "`supply_prog` is not None.") + logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.") return self @@ -241,10 +246,8 @@ class AutoTuner: self._kernel_parameters = k_parameters self._function_parameters = f_parameters - def generate_cache_key(self, parameters: dict[str, Any], - extra_parameters: dict[str, Any]) -> AutotuneResult | None: - """Generate a cache key for the auto-tuning process. - """ + def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None: + """Generate a cache key for the auto-tuning process.""" def _normalize_param(value): if isinstance(value, Var): @@ -315,8 +318,9 @@ class AutoTuner: if var_name in parameters: continue # Cell content must be serializable - assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \ + assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), ( f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}" + ) extra_parameters[var_name] = cell.cell_contents if isinstance(self.configs, Callable): @@ -328,8 +332,10 @@ class AutoTuner: if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): # First check in-memory cache if key in self._memory_cache: - logger.warning("Found kernel in memory cache. For better performance," \ - " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.") + logger.warning( + "Found kernel in memory cache. For better performance," + " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel." + ) return self._memory_cache[key] # Then check disk cache @@ -369,7 +375,6 @@ class AutoTuner: # This encapsulates the logic of using either a custom supply program (`supply_prog`) # or the default profiler input generation (`profiler._get_inputs`). def get_input_tensors_supply(with_output: bool): - def func(): if supply_prog is not None: return supply_prog(profiler._get_params(with_output=with_output)) @@ -387,8 +392,7 @@ class AutoTuner: self.jit_input_tensors = jit_input_tensors_supply() else: # check if the cached tensors are compatible with the current configuration - assert len(params) == len( - self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" + assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" for p, c in zip(params, self.jit_input_tensors): if not isinstance(c, torch.Tensor): # skip non-tensor inputs checking @@ -397,8 +401,8 @@ class AutoTuner: # Check tensor compatibility using generator expression def shape_equal(a, b): return all( - a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) - for a_dim, b_dim in zip(a.shape, b.shape)) + a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape) + ) if p.dtype != c.dtype or not shape_equal(p, c): logger.warning( @@ -409,7 +413,8 @@ class AutoTuner: "To ensure fresh, compatible inputs are generated for every trial " "you can disable caching by setting:\n" " `cache_input_tensors=False`\n" - "within your `.set_compile_args(...)` call.\n") + "within your `.set_compile_args(...)` call.\n" + ) # otherwise, regenerate the input tensors for safety self.jit_input_tensors = jit_input_tensors_supply() break @@ -418,24 +423,16 @@ class AutoTuner: if (not skip_check) and (ref_prog is not None): if manual_check_prog is not None: - profiler.manual_assert_close( - ref_prog, - input_tensors=self.jit_input_tensors, - manual_check_prog=manual_check_prog) + profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog) else: profiler.assert_allclose( - ref_prog, - input_tensors=self.jit_input_tensors, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio) - latency = profiler.do_bench( - warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) + ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio + ) + latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) if self.ref_latency_cache is None and ref_prog is not None: self.ref_input_tensors = ref_input_tensors_supply() - self.ref_latency_cache = profiler.do_bench( - ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) + self.ref_latency_cache = profiler.do_bench(ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) return latency, self.ref_latency_cache @@ -469,17 +466,14 @@ class AutoTuner: # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple if any(key in top_config for key, _ in key_kwargs_tuple) or any( - check_tunable_argument_value(key, self._function_parameters, key_args_tuple) - for key in tunable_arguments): + check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments + ): logger.warning( f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" ) # compile the kernel with the provided parameters jit_kernel = self.jit_compile() - autotuner_result = AutotuneResult( - libcode=jit_kernel.get_kernel_source(), - func=jit_kernel.prim_func, - kernel=jit_kernel) + autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) self._memory_cache[key] = autotuner_result return autotuner_result # get the cpu count @@ -489,9 +483,7 @@ class AutoTuner: max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) if cpu_counts > 0: num_workers = min(cpu_counts, available_cpu_count) - logger.info( - f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" - ) + logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used") else: num_workers = max(1, int(available_cpu_count * cpu_utilizations)) logger.info( @@ -509,7 +501,6 @@ class AutoTuner: future_to_index = {} def cuda_device_wrapper(func, device): - def inner(**config_arg): torch.cuda.set_device(device) return func(**config_arg) @@ -532,18 +523,14 @@ class AutoTuner: future_to_index[future] = i results_with_configs = [] - for future in tqdm( - concurrent.futures.as_completed(futures), - total=len(futures), - desc="Compiling configurations"): + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"): idx = future_to_index[future] config = config_args[idx] try: result = future.result() results_with_configs.append((result, config)) except Exception as e: - logger.debug( - f"Compilation failed for config {config} at index {idx} with error: {e}") + logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}") continue ref_latency = None @@ -556,14 +543,10 @@ class AutoTuner: # latency, ref_latency = target_fn(jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) except TimeoutException: - logger.warning( - f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" - ) + logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details") continue except Exception: - logger.warning( - f"An error occurred while testing config {config}, checkout autotuner.log for more details" - ) + logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details") logger.debug(f"Error: {traceback.format_exc()}") continue @@ -578,8 +561,7 @@ class AutoTuner: pool.shutdown() if best_kernel is None: - error_msg = ("Auto-tuning failed: No configuration successfully " - "compiled and passed benchmarking/validation.") + error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation." logger.error(error_msg) raise RuntimeError(error_msg) @@ -595,7 +577,8 @@ class AutoTuner: ref_latency=ref_latency, libcode=best_kernel.get_kernel_source(), func=best_kernel.prim_func, - kernel=best_kernel) + kernel=best_kernel, + ) if self.compile_args.execution_backend in ("torch"): logger.warning("DLPack backend does not support cache saving to disk.") @@ -617,8 +600,8 @@ class AutoTuner: return self.run() -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") @dataclass @@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]): self._tuner_cache = {} def get_tunner(self): - autotuner = AutoTuner( - self.jit_impl.func, configs=self.configs).set_profile_args( + autotuner = ( + AutoTuner(self.jit_impl.func, configs=self.configs) + .set_profile_args( supply_type=self.supply_type, ref_prog=self.ref_prog, supply_prog=self.supply_prog, @@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]): skip_check=self.skip_check, manual_check_prog=self.manual_check_prog, cache_input_tensors=self.cache_input_tensors, - ).set_compile_args( + ) + .set_compile_args( out_idx=self.jit_impl.out_idx, execution_backend=self.jit_impl.execution_backend, target=self.jit_impl.target, @@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]): verbose=self.jit_impl.verbose, pass_configs=self.jit_impl.pass_configs, ) + ) autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout) return autotuner @@ -753,16 +739,13 @@ def autotune( # This is the new public interface if callable(func): # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults) # This is a placeholder for a real auto tuner implementation - raise ValueError( - "Use tilelang.autotune to decorate func without arguments is not supported yet.") + raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.") elif isinstance(func, PrimFunc): raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") else: def decorator(impl): - assert isinstance( - impl, JITImpl - ), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." + assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." return AutoTuneImpl( jit_impl=impl, configs=configs, diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index 144c272..18ac847 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -1,4 +1,5 @@ """The cache utils with class and database persistence - Init file""" + from __future__ import annotations from typing import Literal @@ -18,8 +19,7 @@ def cached( *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] - | None = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto", verbose: bool | None = False, pass_configs: dict | None = None, compile_flags: list[str] | str | None = None, @@ -36,7 +36,8 @@ def cached( execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, - compile_flags=compile_flags) + compile_flags=compile_flags, + ) def clear_cache(): @@ -47,9 +48,11 @@ def clear_cache(): RuntimeError: Always raised to warn users to clear the cache manually. """ cache_dir = env.TILELANG_CACHE_DIR - raise RuntimeError("tilelang.clear_cache() is disabled because deleting the cache directory " - "is dangerous. If you accept the risk, remove it manually with " - f"`rm -rf '{cache_dir}'`.") + raise RuntimeError( + "tilelang.clear_cache() is disabled because deleting the cache directory " + "is dangerous. If you accept the risk, remove it manually with " + f"`rm -rf '{cache_dir}'`." + ) if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 74ecb27..4fbe2dc 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -1,4 +1,5 @@ """The cache utils with class and database persistence - KernelCache Class""" + from __future__ import annotations import json @@ -97,9 +98,7 @@ class KernelCache: "version": __version__, "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), - "args_repr": tuple( - repr(arg) for arg in args - ), # Use repr to serialize arguments, may need more robust serialization + "args_repr": tuple(repr(arg) for arg in args), # Use repr to serialize arguments, may need more robust serialization "target": str(target), "target_host": str(target_host) if target_host else None, "execution_backend": execution_backend, @@ -118,8 +117,7 @@ class KernelCache: *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", verbose: bool = False, pass_configs: dict = None, compile_flags: list[str] | str | None = None, @@ -140,6 +138,7 @@ class KernelCache: # Normalize target and resolve execution backend before proceeding from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + norm_target = Target(_determine_target(target)) if isinstance(target, str) else target requested_backend = execution_backend execution_backend = resolve_execution_backend(requested_backend, norm_target) @@ -180,21 +179,21 @@ class KernelCache: with self._lock: # First check in-memory cache if key in self._memory_cache: - self.logger.warning("Found kernel in memory cache. For better performance," \ - " consider using `@tilelang.jit` instead of direct kernel caching.") + self.logger.warning( + "Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching." + ) return self._memory_cache[key] if verbose: self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") # Then check disk cache - kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx, - execution_backend, pass_configs, compile_flags, - func, verbose) + kernel = self._load_kernel_from_disk( + key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose + ) if kernel is not None: if verbose: - self.logger.debug( - f"Found kernel in disk cache for {func.attrs['global_symbol']}") + self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}") # Populate memory cache with disk result self._memory_cache[key] = kernel return kernel @@ -262,11 +261,7 @@ class KernelCache: executable.export_library(temp_path) os.replace(temp_path, path) - def _save_kernel_to_disk(self, - key: str, - kernel: JITKernel, - func: Callable = None, - verbose: bool = False): + def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False): """ Persists a compiled kernel to disk cache. @@ -292,8 +287,7 @@ class KernelCache: if verbose: self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - KernelCache._safe_write_file(device_kernel_path, "w", - lambda file: file.write(kernel.kernel_source)) + KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) except Exception as e: self.logger.error(f"Error saving kernel source code to disk: {e}") @@ -303,13 +297,9 @@ class KernelCache: if verbose: self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") if self.execution_backend == "tvm_ffi": - KernelCache._safe_write_file( - host_kernel_path, "w", - lambda file: file.write(kernel.adapter.get_host_source())) + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source())) else: - KernelCache._safe_write_file( - host_kernel_path, "w", - lambda file: file.write(kernel.adapter.get_kernel_source())) + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source())) except Exception as e: self.logger.error(f"Error saving host kernel source code to disk: {e}") @@ -332,9 +322,7 @@ class KernelCache: src_lib_path = src_lib_path.replace(".cubin", ".py") if verbose: self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - KernelCache._safe_write_file( - kernel_py_path, "wb", - lambda file: file.write(KernelCache._load_binary(src_lib_path))) + KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) elif self.execution_backend == "tvm_ffi": executable = kernel.adapter.executable if verbose: @@ -344,9 +332,7 @@ class KernelCache: src_lib_path = kernel.adapter.libpath if verbose: self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - KernelCache._safe_write_file( - kernel_lib_path, "wb", - lambda file: file.write(KernelCache._load_binary(src_lib_path))) + KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) except Exception as e: self.logger.error(f"Error saving kernel library to disk: {e}") @@ -356,8 +342,7 @@ class KernelCache: params_path = os.path.join(cache_path, PARAMS_PATH) if verbose: self.logger.debug(f"Saving kernel parameters to disk: {params_path}") - KernelCache._safe_write_file(params_path, "wb", - lambda file: cloudpickle.dump(kernel.params, file)) + KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) except Exception as e: self.logger.error(f"Error saving kernel parameters to disk: {e}") @@ -417,8 +402,7 @@ class KernelCache: self.logger.error(f"Error loading kernel source code from disk: {e}") try: if verbose: - self.logger.debug( - f"Loading wrapped kernel source code from file: {host_kernel_path}") + self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") with open(host_kernel_path) as f: host_kernel_source = f.read() except Exception as e: diff --git a/tilelang/carver/__init__.py b/tilelang/carver/__init__.py index 4ffd436..f1dfc5b 100644 --- a/tilelang/carver/__init__.py +++ b/tilelang/carver/__init__.py @@ -1,4 +1,5 @@ """Base infra""" + from .analysis import ( BlockInfo, # noqa: F401 IterInfo, # noqa: F401 diff --git a/tilelang/carver/analysis.py b/tilelang/carver/analysis.py index 96606e7..6ca9168 100644 --- a/tilelang/carver/analysis.py +++ b/tilelang/carver/analysis.py @@ -1,4 +1,5 @@ """Analysis on TIR blocks, loops and functions.""" + from __future__ import annotations from typing_extensions import Literal @@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None: var=iter.var, dom=iter.dom, loop_rv=loop, - ) for loop, iter in zip(loops, iters) + ) + for loop, iter in zip(loops, iters) ], block_rv=block, reduction_block=is_reduction, - )) + ) + ) return blocks @@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int: _assert_gpu_target(target) max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) if max_shared_memory_per_block is None: - raise ValueError( - f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") + raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") return int(max_shared_memory_per_block) @@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: try: block = sch.mod[func_name].body.block except Exception: - raise ValueError(f"The function body is expected to be the root block, but got:\n" - f"{sch.mod[func_name].body}") from None + raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None return sch.get_block(block.name_hint) -def collect_block_iter_vars_used_in_access_region(block: tir.Block, - region: list[ir.Range]) -> set[tir.Var]: +def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]: """Collect the block iter variables used in the access region of a buffer region.""" tir_vars = set() for expr in region: @@ -251,15 +251,13 @@ def is_broadcast_epilogue( for buffer_region in sch.get(epilogue).reads: if buffer_region.buffer not in write_buffers: continue - tir_vars = collect_block_iter_vars_used_in_access_region( - sch.get(epilogue), buffer_region.region) + tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region) if len(tir_vars) < len(epilogue_iters): return True return False -def get_reduction_blocks(sch: tir.Schedule, - blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]: +def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]: # Get the main computation block def is_reduction(block: BlockRV) -> bool: block_stmt = sch.get(block) diff --git a/tilelang/carver/arch/__init__.py b/tilelang/carver/arch/__init__.py index c2bc9c7..b6cb9e7 100644 --- a/tilelang/carver/arch/__init__.py +++ b/tilelang/carver/arch/__init__.py @@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice: __all__ = [ - 'is_cpu_arch', - 'is_cuda_arch', - 'is_volta_arch', - 'is_ampere_arch', - 'is_ada_arch', - 'is_hopper_arch', - 'is_tensorcore_supported_precision', - 'has_mma_support', - 'is_cdna_arch', - 'is_metal_arch', - 'CUDA', - 'CDNA', - 'METAL', - 'CPU', + "is_cpu_arch", + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", + "is_cdna_arch", + "is_metal_arch", + "CUDA", + "CDNA", + "METAL", + "CPU", ] diff --git a/tilelang/carver/arch/arch_base.py b/tilelang/carver/arch/arch_base.py index 4c8825e..c5e9dfa 100644 --- a/tilelang/carver/arch/arch_base.py +++ b/tilelang/carver/arch/arch_base.py @@ -7,9 +7,7 @@ class TileDevice: self.reg_cap: int = 0 # Register capacity: The amount of register memory available self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available self.compute_max_core: int = 0 # The maximum number of computing cores - self.warp_size: int = ( - 0 # The size of a warp, a group of threads that execute instructions in lockstep - ) + self.warp_size: int = 0 # The size of a warp, a group of threads that execute instructions in lockstep self.sm_partition: int = 0 # The number of streaming multiprocessor partitions self.transaction_size: list[int] = [ 0, @@ -21,9 +19,7 @@ class TileDevice: 0, ] # Bandwidth specifications, possibly including peak and sustained rates self.platform: str = "unknown" # The platform or manufacturer of the device - self.compute_capability: str = ( - "unknown" # The compute capability, indicating the feature set and performance level - ) + self.compute_capability: str = "unknown" # The compute capability, indicating the feature set and performance level self.l2_cache_size_bytes: int = 0 # the number of transaction size in bytes self.transaction_size: list[int] = [0, 0] # in bytes diff --git a/tilelang/carver/arch/cdna.py b/tilelang/carver/arch/cdna.py index ec5aa90..5c2d4c4 100644 --- a/tilelang/carver/arch/cdna.py +++ b/tilelang/carver/arch/cdna.py @@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool: class CDNA(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = tvm.target.Target(target) @@ -33,6 +32,6 @@ class CDNA(TileDevice): __all__ = [ - 'is_cdna_arch', - 'CDNA', + "is_cdna_arch", + "CDNA", ] diff --git a/tilelang/carver/arch/cpu.py b/tilelang/carver/arch/cpu.py index f4643ba..fc18c6c 100644 --- a/tilelang/carver/arch/cpu.py +++ b/tilelang/carver/arch/cpu.py @@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool: # For LLVM Backend, we do not provide the detailed information of the CPU # As the LLVM backend do not required tuning, just maintain the consistency class CPU(TileDevice): - def __init__(self, target: Target): self.target = target device = tvm.runtime.cpu(0) @@ -21,6 +20,6 @@ class CPU(TileDevice): __all__ = [ - 'is_cpu_arch', - 'CPU', + "is_cpu_arch", + "CPU", ] diff --git a/tilelang/carver/arch/cuda.py b/tilelang/carver/arch/cuda.py index 4c7f98d..2b79b28 100644 --- a/tilelang/carver/arch/cuda.py +++ b/tilelang/carver/arch/cuda.py @@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported # instead of assuming both a and b share the same dtype. # As the tensorcore may supports float8_e4m3 * float8_e5m2 def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: - if is_volta_arch(arch): return (in_dtype, accum_dtype) in volta_tensorcore_supported elif is_ampere_arch(arch): @@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til class TensorInstruction: - def __init__( self, name: str, @@ -104,7 +102,6 @@ class TensorInstruction: class CUDA(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = tvm.target.Target(target) @@ -148,12 +145,12 @@ class CUDA(TileDevice): __all__ = [ - 'is_cuda_arch', - 'is_volta_arch', - 'is_ampere_arch', - 'is_ada_arch', - 'is_hopper_arch', - 'is_tensorcore_supported_precision', - 'has_mma_support', + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", "CUDA", ] diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py index c8cc1a3..a631276 100644 --- a/tilelang/carver/arch/driver/cuda_driver.py +++ b/tilelang/carver/arch/driver/cuda_driver.py @@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. """ assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" - shared_mem = get_device_attribute( - cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) + shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) if format == "bytes": return shared_mem elif format == "kb": diff --git a/tilelang/carver/arch/metal.py b/tilelang/carver/arch/metal.py index 9cd1c4d..0b76849 100644 --- a/tilelang/carver/arch/metal.py +++ b/tilelang/carver/arch/metal.py @@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool: class METAL(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = Target(target) @@ -16,6 +15,6 @@ class METAL(TileDevice): __all__ = [ - 'is_metal_arch', - 'METAL', + "is_metal_arch", + "METAL", ] diff --git a/tilelang/carver/common_schedules.py b/tilelang/carver/common_schedules.py index 199f015..4904b77 100644 --- a/tilelang/carver/common_schedules.py +++ b/tilelang/carver/common_schedules.py @@ -19,6 +19,7 @@ # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm common_schedules.py in dlight. """Common schedule strategies for TIR.""" + from typing import Callable from tvm import tir diff --git a/tilelang/carver/matmul_analysis.py b/tilelang/carver/matmul_analysis.py index 02a86cc..6d27de8 100644 --- a/tilelang/carver/matmul_analysis.py +++ b/tilelang/carver/matmul_analysis.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" + from __future__ import annotations from dataclasses import dataclass from enum import Enum @@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block return block -def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, - buffer: tir.Buffer) -> int: +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int: """traverse to find the arg index from the buffer""" producers = sch.get_producers(main_block) @@ -226,9 +226,7 @@ def make_iter_fusion_index_map( else: fused_iters[trait.kind] = v_i - final_indices: list[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order - ] + final_indices: list[tir.PrimExpr] = [fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order] return tir.IndexMap(input_iters, final_indices, None) @@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None: return A_traits, B_traits, C_traits, block_traits -def get_index_map(block: tir.Block, - layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None: +def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None: """Get index maps for the block Parameters @@ -343,10 +340,7 @@ def get_index_map(block: tir.Block, return axes def is_common_reduce(var: Var) -> bool: - for iter_var in block.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars) def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) @@ -384,17 +378,17 @@ def get_index_map(block: tir.Block, if kind == "C": return [IterKind.kIter_S, primary_iter, secondary_iter] else: - return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter]) + return ( + [IterKind.kIter_S, spatial_iter, reduction_iter] + if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter] + ) else: raise ValueError(f"Unknown layout {layout}") - A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) - B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) - C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) + A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) + B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) + C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) matmul_index_map = make_iter_fusion_index_map( block_traits, @@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None: has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) if not has_uint_input: return False - return not (len(block_stmt.writes) != 1 or - "float" not in str(block_stmt.writes[0].buffer.dtype)) + return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype)) dequantize_blocks = [block for block in blocks if is_dequantize(block)] return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None @@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: return None axes.extend(undefined_vars(r.min)) # remove trivial axis - trivial_vars = set( - iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) + trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) axes = [axis for axis in axes if axis not in trivial_vars] # remove duplicate axis axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] @@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( - rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars) return is_identity, is_transpose @@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV] return result_blocks -def normalize_to_matmul(sch: tir.Schedule, - main_block: BlockRV, - layout: list[str] | None = None) -> tir.Schedule | None: +def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None: if layout is None: layout = ["n", "t", "n"] block_stmt = sch.get(main_block) @@ -526,7 +515,7 @@ def get_tensorized_func_and_tags( allow_gemv: bool = False, ) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]: """ - transform function to matmul if necessary (e.g. transform conv2d with im2col) + transform function to matmul if necessary (e.g. transform conv2d with im2col) """ if layout is None: layout = ["a", "a", "a"] @@ -543,10 +532,7 @@ def get_tensorized_func_and_tags( conditions = [] conditions.append(len(block_stmt.reads) == 2) conditions.append(len(block_stmt.writes) == 1) - conditions.append( - len( - collect_block_iter_vars_used_in_access_region(block_stmt, - block_stmt.writes[0].region)) > 0) + conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0) return all(conditions) # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) @@ -592,10 +578,7 @@ def get_tensorized_func_and_tags( return axes def is_common_reduce(var: Var) -> bool: - for iter_var in block_stmt.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars) def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) @@ -626,7 +609,7 @@ def get_tensorized_func_and_tags( # When the func is a dequantize like ops, we should consider the M require_block_reduce = False # And we only support float16 for now - if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]): + if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]: for arg in func.params: inp_shape = func.buffer_map[arg].shape M = inp_shape[0] @@ -645,9 +628,7 @@ def get_tensorized_func_and_tags( if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): - logger.debug( - f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore" - ) + logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore") return func, None # reindex and transform functions @@ -676,7 +657,7 @@ def get_tensorized_func_and_tags( else: raise ValueError(f"Unknown IterVar type {iter_type}") - if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): + if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold: return func, None tags = analysis_tensorcore_tags(sch, main_block, target) return sch.mod["main"], tags @@ -686,8 +667,10 @@ def get_tensorized_func_and_tags( def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, ) assert dtype in [ @@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde return ldmatrix_layout(thread_id, local_id) if dtype in ["bfloat16", "float16"]: - ldmatrix_index_map = ( - ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans else ldmatrix_permutation_16x16_32x8_16x16) + ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16 else: ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 @@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde # Ladder weight propagation, which can be used to avoid the ldmatrix # Instructions. def get_ladder_stage3_map(dtype="float16", index_dtype="int32"): - def shared_32x8_to_mma_32x8_layout(i, j): thread_id = (i % 8) * 4 + (j // 2) local_id = (i // 8) * 2 + (j % 2) @@ -837,8 +817,7 @@ def layout_propagate_chain( scaling_factor = 1 for i, j in zip(write.buffer.shape, read.buffer.shape): scaling_factor *= i // j - final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices))) + final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices))) final_indices[-1] = final_indices[-1] // scaling_factor index_map = IndexMap( write_indices, diff --git a/tilelang/carver/roller/bestfit.py b/tilelang/carver/roller/bestfit.py index b66ceaa..ec78174 100644 --- a/tilelang/carver/roller/bestfit.py +++ b/tilelang/carver/roller/bestfit.py @@ -2,7 +2,6 @@ class Block: - def __init__(self, start, end, is_free): self.start = start self.end = end @@ -21,7 +20,6 @@ class Block: class BestFit: - def __init__(self, align=32): self.limit = 0 self.list = [] @@ -31,16 +29,14 @@ class BestFit: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and (not found or - found.size() > block.size()): + if block.is_free and block.size() >= size and (not found or found.size() > block.size()): found = block if found: found.is_free = False remain = found.size() - size if remain != 0: found.end -= remain - self.list.insert( - self.list.index(found) + 1, Block(found.end, found.end + remain, True)) + self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True)) return found elif len(self.list) > 0 and self.list[-1].is_free: add = size - self.list[-1].size() diff --git a/tilelang/carver/roller/hint.py b/tilelang/carver/roller/hint.py index 17c69da..8fd1fb4 100644 --- a/tilelang/carver/roller/hint.py +++ b/tilelang/carver/roller/hint.py @@ -1,4 +1,5 @@ """Hint definition for schedule""" + from tvm import DataType from . import PrimFuncNode import numpy as np @@ -60,7 +61,7 @@ class Stride: strided_elem = original_shape else: assert self.ax < len(shape) - strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride + strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride assert strided_elem >= original_shape return int(strided_elem) @@ -217,7 +218,7 @@ class Hint: return dic @classmethod - def from_dict(cls, dic: dict) -> 'Hint': + def from_dict(cls, dic: dict) -> "Hint": hint = cls() for k, v in dic.items(): setattr(hint, k, v) diff --git a/tilelang/carver/roller/node.py b/tilelang/carver/roller/node.py index f9e38b1..3122c7b 100644 --- a/tilelang/carver/roller/node.py +++ b/tilelang/carver/roller/node.py @@ -1,4 +1,5 @@ """PrimFunc Wrapper and Block information Analaysis""" + from __future__ import annotations import tvm @@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func): class BlockAnalyzer: - def __init__(self, sch) -> None: self.sch: tir.Schedule = sch self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch) @@ -92,7 +92,6 @@ class Edge: class Node: - def __init__(self, tags: dict | None = None, name: str = "Node") -> None: self.name = name if tags is None: @@ -177,7 +176,6 @@ class Node: class PlaceHolderNode(Node): - def __init__(self, name=""): super().__init__(name="PlaceHolder_" + name) @@ -189,11 +187,7 @@ class PlaceHolderNode(Node): class PrimFuncNode(Node): - - def __init__(self, - prim_func: PrimFunc, - tags: dict | None = None, - name: str = "PrimFuncNode") -> None: + def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None: super().__init__(tags, name=name) self.prim_func = self._specialize_func(prim_func) self.sch: tir.Schedule = tir.Schedule(self.prim_func) @@ -227,7 +221,7 @@ class PrimFuncNode(Node): for dst_id, n in enumerate(inputs): if isinstance(n, Node): n = (n, 0) - assert (len(n) == 2) + assert len(n) == 2 src_node, src_id = n[0], n[1] edge = Edge(src_node, self, src_id, dst_id) self._in_edges.append(edge) @@ -338,9 +332,8 @@ class PrimFuncNode(Node): if rstep is None: rstep = {} shape = { - self.block_analyzer.get_output_buffers(block)[0].name: [ - tvm.arith.ConstIntBound(0, val - 1) for val in tile - ] for block in self.schedule_stages + self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile] + for block in self.schedule_stages } return self.ana.infer(shape, rstep, targets) @@ -356,10 +349,7 @@ class PrimFuncNode(Node): results.append(shapes[arg.name]) continue # should not exceed original shape - trimmed_shape = [ - self.extent_wrapper(i) - for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) - ] + trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))] results.append(trimmed_shape) return results @@ -380,10 +370,8 @@ class PrimFuncNode(Node): propagate_shape = shapes[arg.name] buffer_shape = args[i].shape if len(buffer_shape) > len(propagate_shape): - buffer_shape = buffer_shape[-len(propagate_shape):] - trimmed_shape = [ - self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape))) - ] + buffer_shape = buffer_shape[-len(propagate_shape) :] + trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))] results.append(trimmed_shape) return results @@ -412,10 +400,7 @@ class PrimFuncNode(Node): def get_reduce_inputs_dtype(self): if self.reduction_block is None: return {} - return { - b.name: tvm.DataType(b.dtype) - for b in self.block_analyzer.get_input_buffers(self.reduction_block) - } + return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)} @functools.lru_cache def infer_tensorcore_axis(self) -> tuple[int]: @@ -425,8 +410,7 @@ class PrimFuncNode(Node): C_ax_m, C_ax_n = self.get_tag("tensorcore_config") wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok - output_buffer_shape = ( - self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape) + output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape valid_region = [] for region in output_buffer_shape: if region.value == 1: @@ -438,8 +422,7 @@ class PrimFuncNode(Node): def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): spatial_dim = self.get_space_dim() - assert len(valid_region) == len( - spatial_dim), f" {valid_region} mismatch with {spatial_dim}" + assert len(valid_region) == len(spatial_dim), f" {valid_region} mismatch with {spatial_dim}" cl_shapes = [1] * len(spatial_dim) cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n @@ -467,9 +450,11 @@ class PrimFuncNode(Node): shapes, _ = self.propagate(shape, rstep) def is_broadcast_pattern(buffer, output_buffer): - return (buffer in self.args and - len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and - np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) + return ( + buffer in self.args + and len(shapes[output_buffer.name]) > len(shapes[buffer.name]) + and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]) + ) def is_after_reduce_stage(block): if not self.reduction_block: @@ -491,8 +476,8 @@ class PrimFuncNode(Node): output_buffer = self.block_analyzer.get_output_buffers(block)[0] for buffer in self.block_analyzer.get_input_buffers(block): cache = buffer.name not in cached_tensor and ( - is_broadcast_pattern(buffer, output_buffer) or - self.block_analyzer.get_block_info(block).is_reduction()) + is_broadcast_pattern(buffer, output_buffer) or self.block_analyzer.get_block_info(block).is_reduction() + ) if not cache: continue cached_tensor.append(buffer.name) @@ -500,8 +485,7 @@ class PrimFuncNode(Node): continue # cache after reduce op can often reuse buffer in reduce stage if buffer.name in stride_map: - num_elem = stride_map[buffer.name].compute_elements_from_shape( - shapes[buffer.name]) + num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name]) else: num_elem = np.prod(shapes[buffer.name]) buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) @@ -514,7 +498,6 @@ class PrimFuncNode(Node): class OutputNode(Node): - def __init__(self, node, id=0): super().__init__(name="OutputNode") # connect node and output node @@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]: input_ready_count[dst_node] = len(dst_node.inputs) list_of_nodes.append(dst_node) input_ready_count[dst_node] -= 1 - assert (input_ready_count[dst_node] >= 0) + assert input_ready_count[dst_node] >= 0 if input_ready_count[dst_node] == 0: ready.append(dst_node) - assert (len(list_of_nodes) == len(output_list)) + assert len(list_of_nodes) == len(output_list) return output_list def find_topo_sort_priority(output_node_list) -> list[Node]: import sys + sys.setrecursionlimit(10000) def topo_sort_get_layer(node, topo_layer): @@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]: if node in visited: return visited.add(node) - ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], - key=lambda n: topo_layer[n], - reverse=True) + ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True) for n in ordered_input_nodes: topo_sort_dfs(n, visited, topo_order) topo_order.append(node) @@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]: def find_topo_sort(output_node_list) -> list[Node]: - def topo_sort_dfs(node, visited, topo_order): if node in visited: return diff --git a/tilelang/carver/roller/policy/default.py b/tilelang/carver/roller/policy/default.py index 161df27..d09216e 100644 --- a/tilelang/carver/roller/policy/default.py +++ b/tilelang/carver/roller/policy/default.py @@ -1,4 +1,5 @@ """Policy for cuda core schedule""" + from __future__ import annotations import functools import math @@ -36,20 +37,14 @@ class DefaultPolicy: self.rasterization = NoRasterization() @classmethod - def from_prim_func(cls, - func: tvm.tir.PrimFunc, - arch: TileDevice, - tags: dict | None = None, - name: str = "PrimFuncNode"): + def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"): return cls(arch, tags)._init_with_prim_func(func, name) @classmethod def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None): return cls(arch, tags)._init_with_output_nodes(nodes) - def _init_with_prim_func(self, - func: tvm.tir.PrimFunc, - name: str = "PrimFuncNode") -> DefaultPolicy: + def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy: if func is not None and isinstance(func, tvm.tir.PrimFunc): self.func = func self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name) @@ -60,9 +55,7 @@ class DefaultPolicy: return self def _init_with_output_nodes(self, output_nodes: list[OutputNode]): - self.ordered_nodes = list( - filter(lambda n: not n.is_placeholder() and not n.is_output(), - find_topo_sort(output_nodes))) + self.ordered_nodes = list(filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes))) for node in self.ordered_nodes: node.update_tags(self.tags) @@ -102,13 +95,14 @@ class DefaultPolicy: def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()] - steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] + steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)] for i in range(len(steps)): added = list( filter( lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], [2, 4, 8, 16, 32], - )) + ) + ) steps[i].extend(added) steps[i] = sorted(steps[i]) visited_tiles = {} @@ -190,10 +184,7 @@ class DefaultPolicy: """ tile_map = {} for node in self.output_nodes: - tile_map[node] = [ - tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] - for i in range(len(tile)) - ] + tile_map[node] = [tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))] return tile_map def compute_workload_per_item(self, output_tile) -> float: @@ -304,8 +295,7 @@ class DefaultPolicy: score = 0 shape = node.propagate_inputs(tile, rstep=rstep) for i, input_buffer in enumerate(node.input_buffers): - read_transaction_elements = self.arch.transaction_size[1] // ( - (node.get_buffer_dtype(input_buffer).bits + 7) // 8) + read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8) score += sim( int(coalesced_factor(shape[i], input_buffer.shape)), read_transaction_elements, @@ -380,17 +370,13 @@ class DefaultPolicy: return None return max(candidates, key=lambda x: x[1])[0] - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} new_rstep_map = rstep_map.copy() while True: new_rstep_id = _enlarge(cur_rstep_id) if new_rstep_id is None: break - new_rstep_map[node] = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } + new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} old_rstep_map = td.rstep_map td.rstep_map = new_rstep_map smem_usage, _ = self._compute_shared_memory_usage(td) @@ -434,15 +420,14 @@ class DefaultPolicy: if edge.src_node.is_placeholder(): nbytes = (edge.src_node.get_dtype().bits + 7) // 8 read_transaction_elements = self.arch.transaction_size[1] // nbytes - traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), - read_transaction_elements) * nbytes + traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes for edge in node.outputs: if edge.dst_node.is_output(): nbytes = (edge.src_node.get_dtype().bits + 7) // 8 write_transaction_elements = self.arch.transaction_size[0] // nbytes - traffic += coalesced_tensor_shape(output_shapes[edge.src_id], - node.get_shape(edge.src_id), - write_transaction_elements) * nbytes + traffic += ( + coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements) * nbytes + ) return traffic, op_tile_map @@ -487,10 +472,7 @@ class DefaultPolicy: cached_tensors_map = {} def can_free(node, out_id): - for edge in node.outputs: - if edge.src_id == out_id and edge.dst_node not in processed: - return False - return True + return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs) for node in self.ordered_nodes: node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node) @@ -528,9 +510,7 @@ class DefaultPolicy: Tuple[Dict, Dict] A tuple of dictionaries containing the output strides and tensor strides. """ - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} tensor_strides = {} return output_strides, tensor_strides @@ -551,8 +531,7 @@ class DefaultPolicy: output_strides_map = {} tensor_strides_map = {} for node in self.ordered_nodes: - output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( - node, td) + output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td) td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict: @@ -582,9 +561,7 @@ class DefaultPolicy: output_shape = self.output_nodes[0].get_space_dim() td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) # estimated reg usage - reg_usage = int(2 * max([ - np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes - ])) + reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes])) if reg_usage > self.arch.reg_cap: td.valid = False return td @@ -609,13 +586,10 @@ class DefaultPolicy: for node in self.ordered_nodes: if np.prod(td.get_tile(node)) == 0: return False - node_grid_size = np.prod([ - (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim()) - ]) + node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())]) if node_grid_size != td.grid_size: return False - if (hasattr(node, "reduce_op") and node.reduce_op is not None and - len(node.reduce_op.axis) == len(td.output_tile)): + if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile): for i, tile_extent in enumerate(td.output_tile): if node.reduce_op.axis[i].dom.extent % tile_extent: return False @@ -639,23 +613,22 @@ class DefaultPolicy: node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] max_block_size = functools.reduce(math.gcd, node_space_sizes) - if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( - node_space_sizes): - node_reduce_sizes = [ - int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes - ] + if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes): + node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes] total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] max_possible_size = functools.reduce(math.gcd, total_sizes) possible_block_sizes = list( filter( lambda x: x % max_block_size == 0 and x <= 1024, get_all_factors(max_possible_size), - )) + ) + ) possible_block_sizes = list( filter( # either be a factor of space or cover fully cover the space lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), possible_block_sizes, - )) + ) + ) factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) return factor_ordered else: @@ -821,8 +794,7 @@ class DefaultPolicy: vectorize_result = {} for tensor, shape in shapes.items(): for v in vectorize_sizes: - if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and - is_type_allowed(dtypes[tensor], v)): + if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v): vectorize_result[tensor] = v break return vectorize_result diff --git a/tilelang/carver/roller/policy/tensorcore.py b/tilelang/carver/roller/policy/tensorcore.py index 15bad41..86c79ea 100644 --- a/tilelang/carver/roller/policy/tensorcore.py +++ b/tilelang/carver/roller/policy/tensorcore.py @@ -1,4 +1,5 @@ """Policy for tensorcore schedule""" + from __future__ import annotations import tvm import numpy as np @@ -13,7 +14,6 @@ logger = logging.getLogger(__name__) class TensorCorePolicy(DefaultPolicy): - # this is the trick for wmma. # However, for int8 mma, the wmma_k should be 32. wmma_k: int = 16 @@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy): A_high_ax = min(A_ax_m, A_ax_k) B_high_ax = min(B_ax_n, B_ax_k) C_high_ax = min(C_ax_m, C_ax_n) - A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) - B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) - C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) + A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax) + B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax) + C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax) return A_stride, B_stride, C_stride def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): @@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy): # get reduce input size target_transaction = self.arch.transaction_size[0] * 2 # 512 bytes // type bits - reduce_input_dtype = node.get_buffer_dtype( - node.block_analyzer.get_input_buffers(node.reduction_block)[0]) + reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0]) basic = (target_transaction * 8) // reduce_input_dtype.bits result = {} @@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy): iter_name = iter_info.var.name iter_dom = iter_info.dom.extent if iter_dom % 16 > 0: - result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding + result[iter_name] = 16 if iter_dom < basic else basic # for the case of padding elif iter_dom % basic == 0: result[iter_name] = basic else: @@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy): return False if _check_small_tile(td): - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) rstep_map = td.rstep_map.copy() @@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy): return rstep def _shared_memory_usage(td: TileDict): - return node.footprint(td.output_tile, new_rstep_map, - td.tensor_strides_map[node]) + return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis} score = 0 shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) @@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy): return None return max(candidates, key=lambda x: x[1])[0] - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} new_rstep_map = rstep_map.copy() while True: new_rstep_id = _enlarge(cur_rstep_id) if new_rstep_id is None: break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] - for k in node.raxis - } + new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} old_rstep_map = td.rstep_map td.rstep_map = new_rstep_map smem_usage, _ = _shared_memory_usage(td) @@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy): break else: cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis} return rstep for node in self.ordered_nodes: @@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy): return super().get_node_reduce_step_candidates(node) else: # must be a a multiple of wmma_k - return { - k.var.name: [ - x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k) - ] for k in node.raxis - } + return {k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis} def check_tile_shape_isvalid(self, td: TileDict): for node in self.ordered_nodes: @@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy): td.tile_map[node][ax_n], ) # check the tile size is valid - wmma_invalid = [ - block_m < wmma_m or block_n < wmma_n - for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes() - ] + wmma_invalid = [block_m < wmma_m or block_n < wmma_n for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()] if all(wmma_invalid): return False if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): @@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy): return super().compute_node_stride_map(node, td) use_layout = self._can_implement_layout(node, td) - AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), - td.get_rstep(node)) + AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node)) A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) tensor_strides = {} - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} tensor_strides = {} # when connected to shared input, should use full stride without rstep for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): @@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy): overall_gmem_size_in_bytes: int = 0 for node in self.ordered_nodes: for buffer in node.input_buffers: - overall_gmem_size_in_bytes += ( - int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8) + overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8 return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes conditions.append(_check_memory_size()) diff --git a/tilelang/carver/roller/rasterization.py b/tilelang/carver/roller/rasterization.py index ebd1319..ec565a1 100644 --- a/tilelang/carver/roller/rasterization.py +++ b/tilelang/carver/roller/rasterization.py @@ -2,7 +2,6 @@ class Rasterization: - panel_width_ = None def __init__(self) -> None: @@ -18,7 +17,6 @@ class Rasterization: class NoRasterization(Rasterization): - def __init__(self) -> None: super().__init__() diff --git a/tilelang/carver/roller/shape_inference/common.py b/tilelang/carver/roller/shape_inference/common.py index c52a170..c29ae41 100644 --- a/tilelang/carver/roller/shape_inference/common.py +++ b/tilelang/carver/roller/shape_inference/common.py @@ -4,9 +4,7 @@ from tvm import arith class Statement: - - def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, - range_map: OrderedDict): + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): self.output = output self.dependent_region = dependent_region self.var_map = var_map @@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): class InputShapeInference: - def __init__(self, deps: list[Statement]): self.deps = deps diff --git a/tilelang/carver/roller/shape_inference/tir.py b/tilelang/carver/roller/shape_inference/tir.py index 618cf9b..d7b11d6 100644 --- a/tilelang/carver/roller/shape_inference/tir.py +++ b/tilelang/carver/roller/shape_inference/tir.py @@ -5,7 +5,6 @@ from tvm import arith, tir class Statement: - def __init__(self, block_analyzer, block: BlockRV): self.block_analyzer = block_analyzer self.block = block @@ -21,9 +20,7 @@ class Statement: if len(self.dependent_region[input_name]) != 1: return None indices = self.dependent_region[input_name][0] - iter_map_range = { - _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block) - } + iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)} iter_map_result = arith.detect_iter_map( indices, iter_map_range, @@ -77,7 +74,6 @@ class TensorDepNode: class DependencyAnalysis: - def __init__(self, deps): self.deps = deps # issue: duplicate name when we have two same ops. @@ -112,8 +108,7 @@ class DependencyAnalysis: def traverse_dependencies(self, compute): if isinstance(compute, Statement): - node = self.get_or_create_node( - compute.block_analyzer.get_output_buffers(compute.block)[0].name) + node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name) # Loop through input tensors for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): # Get the input node @@ -167,7 +162,6 @@ class DependencyAnalysis: class InputShapeInference: - def __init__(self, deps: list[Statement]): self.deps = deps self.target_mapping = {} @@ -183,16 +177,11 @@ class InputShapeInference: if targets in self.target_mapping: return self.target_mapping[targets] # should be buffer name instead of block name - name2dep = { - dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps - } + name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps} mapping = {} input_vars = [] for target in targets: - vars = [ - iter.var - for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block) - ] + vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)] input_vars.append(vars) mapping[target] = [vars] ana = arith.Analyzer() @@ -221,13 +210,8 @@ class InputShapeInference: mapping[input_name] = [] for indices in indices_list: for region in regions: - vmap = { - k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) - for k, v in zip(ax_vars, indices) - } - region = [ - ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region - ] + vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)} + region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region] if not region_exist_in_list(region, mapping[input_name]): mapping[input_name].append(region) buffers = [] @@ -241,10 +225,7 @@ class InputShapeInference: self.target_mapping[targets] = input_vars, mapping return input_vars, mapping - def infer(self, - shape: dict[str, list[arith.ConstIntBound]], - rstep: dict[str, int] = None, - targets=None): + def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None): if rstep is None: rstep = {} compute_targets = tuple(shape.keys()) @@ -258,8 +239,7 @@ class InputShapeInference: for ax in self.reduce_axes: # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. if ax.var.name in rstep: - bound = arith.ConstIntBound( - int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) else: bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) ana.update(ax.var, bound, True) @@ -312,14 +292,11 @@ class InputShapeInference: for name, regions in mapping.items(): region = regions[0] - result[name] = [ - ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region - ] + result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region] return result def region_exist_in_list(a, list) -> bool: - def expr_is_same(a, b) -> bool: if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): return a.value == b.value diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py index a119c16..4a699fb 100644 --- a/tilelang/carver/template/base.py +++ b/tilelang/carver/template/base.py @@ -2,7 +2,12 @@ from abc import ABC, abstractmethod # For defining abstract base classes from dataclasses import dataclass, field # For defining data classes from ..arch import ( # Import architecture-related utilities and classes - TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch) + TileDevice, + is_volta_arch, + is_ampere_arch, + is_cdna_arch, + auto_infer_current_arch, +) from ..roller.hint import Hint # Import the Hint class from ..roller.node import OutputNode # Import the OutputNode class from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions @@ -41,7 +46,7 @@ class BaseTemplate(ABC): """ pass - def with_arch(self, arch: TileDevice) -> 'BaseTemplate': + def with_arch(self, arch: TileDevice) -> "BaseTemplate": """ Sets the architecture for this template and returns itself. @@ -109,7 +114,7 @@ class BaseTemplate(ABC): """ raise NotImplementedError("initialize_function is not implemented") - def set_function(self, func: PrimFunc) -> 'BaseTemplate': + def set_function(self, func: PrimFunc) -> "BaseTemplate": """ Sets the function for this template and returns itself. @@ -122,7 +127,7 @@ class BaseTemplate(ABC): self._func = func return self - def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate': + def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate": """ Sets the output nodes for this template and returns itself. diff --git a/tilelang/carver/template/conv.py b/tilelang/carver/template/conv.py index 9ea8920..c339e58 100644 --- a/tilelang/carver/template/conv.py +++ b/tilelang/carver/template/conv.py @@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate): accum_dtype (str): Data type used for accumulation. with_bias (bool): Whether to add a bias term. """ + # Operation-related configuration parameters N: int # The number of input samples processed simultaneously in a batch. C: int # The number of input feature maps. @@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate): AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers. """ N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P - assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and - isinstance(W, int) and isinstance(F, int) and isinstance(K, int) and - isinstance(S, int) and isinstance(D, int) and - isinstance(P, int)), "Only Support Integer Params" - assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and - P > 0), "Params should be positive" + assert ( + isinstance(N, int) + and isinstance(C, int) + and isinstance(H, int) + and isinstance(W, int) + and isinstance(F, int) + and isinstance(K, int) + and isinstance(S, int) + and isinstance(D, int) + and isinstance(P, int) + ), "Only Support Integer Params" + assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive" # Load configuration parameters in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype @@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate): te.if_then_else( te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W), A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype), - tir.const(0, accum_dtype)), - axis=[kh, kw, c]) + tir.const(0, accum_dtype), + ), + axis=[kh, kw, c], + ) # Compute convolution result C = te.compute( diff --git a/tilelang/carver/template/flashattention.py b/tilelang/carver/template/flashattention.py index ae1a254..933ab95 100644 --- a/tilelang/carver/template/flashattention.py +++ b/tilelang/carver/template/flashattention.py @@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_ @dataclass class FlashAttentionTemplate(BaseTemplate): - _output_nodes: list[OutputNode] = None # Operation-related configuration parameters @@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate): """ A_indices = [b, i, k] B_indices = [b, j, k] - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * - B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/template/gemv.py b/tilelang/carver/template/gemv.py index cdcc78d..e7962f6 100644 --- a/tilelang/carver/template/gemv.py +++ b/tilelang/carver/template/gemv.py @@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate): N, K = self.N, self.K # Ensure M, N, K are valid positive integers - assert (isinstance(M, int) and isinstance(N, int) and - isinstance(K, int)), "Only Support Integer M, N, K" - assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive" + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" # Load configuration parameters trans_B = self.trans_B @@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate): """ A_indices = [i, k] B_indices = [k, j] if not trans_B else [j, k] - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/template/general_reduce.py b/tilelang/carver/template/general_reduce.py index a8da5fd..b7a5515 100644 --- a/tilelang/carver/template/general_reduce.py +++ b/tilelang/carver/template/general_reduce.py @@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func @dataclass class GeneralReductionTemplate(BaseTemplate): - # OP Related Config structure: str | list[str] = None shape: list[int] = None dtype: str = "float16" def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: - roller_hints = get_roller_hints_from_func( - self._func, arch=arch, topk=topk, allow_gemv=False) + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False) return roller_hints def initialize_function(self) -> None: @@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate): spatial_axes = [] reduce_axes = [] for i, axis_type in enumerate(self.structure): - if axis_type.upper() == 'S': + if axis_type.upper() == "S": spatial_axes.append((i, self.shape[i])) - elif axis_type.upper() == 'R': + elif axis_type.upper() == "R": reduce_axes.append((i, self.shape[i])) else: raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.") @@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate): # Walk through the structure in order for axis_type in self.structure: - if axis_type.upper() == 'S': + if axis_type.upper() == "S": # use the next spatial_indices item full_index.append(spatial_indices[spatial_iter]) spatial_iter += 1 diff --git a/tilelang/carver/template/matmul.py b/tilelang/carver/template/matmul.py index 653ddab..57c92be 100644 --- a/tilelang/carver/template/matmul.py +++ b/tilelang/carver/template/matmul.py @@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate): M, N, K = self.M, self.N, self.K # Ensure M, N, K are valid positive integers - assert (isinstance(M, int) and isinstance(N, int) and - isinstance(K, int)), "Only Support Integer M, N, K" - assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive" + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" # Load configuration parameters trans_A, trans_B = self.trans_A, self.trans_B @@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate): """ A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/utils.py b/tilelang/carver/utils.py index cedb754..67db89e 100644 --- a/tilelang/carver/utils.py +++ b/tilelang/carver/utils.py @@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str: """ -def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, - arch: TileDevice, - topk: int = 10, - tensorcore_only: bool = False, - allow_gemv: bool = False) -> list[Hint] | None: +def get_roller_hints_from_func( + func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False +) -> list[Hint] | None: func = None if isinstance(func_or_module, tir.PrimFunc): func = func_or_module @@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, roller_hints = None if tensorcore_only: try: - tensorized_func, tags = get_tensorized_func_and_tags( - func, arch.target, allow_gemv=allow_gemv) + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) except Exception as e_msg: logger.debug("Get tensorized func and tags failed: ", e_msg) tags = None @@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, policy = DefaultPolicy.from_prim_func(func=func, arch=arch) tensorized_func = None try: - tensorized_func, tags = get_tensorized_func_and_tags( - func, arch.target, allow_gemv=allow_gemv) + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) except Exception as e_msg: logger.debug("Get tensorized func and tags failed: ", e_msg) tags = None @@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, return roller_hints -def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], - arch: TileDevice, - topk: int = 10, - extra_tags: list[str] | None = None) -> list[Hint] | None: +def get_roller_hints_from_output_nodes( + output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None +) -> list[Hint] | None: assert isinstance(output_nodes, list), "The input should be a list of functions." lints = [] @@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None) lints = policy.emit_config(topk) except Exception as e_msg: - logger.debug(f"Generate hints from output nodes failed: {e_msg}", - "fallback to default policy") + logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy") if len(lints) == 0: policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None) @@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: if not isinstance(ir_module, IRModule): raise ValueError("Not supported type: ", type(ir_module)) - assert len(ir_module.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." func = list(ir_module.functions.values())[0] return func diff --git a/tilelang/contrib/cc.py b/tilelang/contrib/cc.py index 87d943a..7dc4597 100644 --- a/tilelang/contrib/cc.py +++ b/tilelang/contrib/cc.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Util to invoke C/C++ compilers in the system.""" + import functools import os import shutil @@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils def _is_linux_like(): - return (sys.platform == "darwin" or sys.platform.startswith("linux") or - sys.platform.startswith("freebsd")) + return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd") def _is_windows_like(): @@ -90,7 +90,7 @@ def get_cplus_compiler(): def is_darwin(): - return platform.system() == 'Darwin' + return platform.system() == "Darwin" def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None): @@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll" create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc())) -def cross_compiler(compile_func, - options=None, - output_format=None, - get_target_triple=None, - add_files=None): +def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None): """Create a cross compiler function by specializing compile_func with options. This function can be used to construct compile functions that @@ -363,13 +359,7 @@ def cross_compiler(compile_func, return _fcompile -def _linux_compile(output, - objects, - options, - compile_cmd, - cwd=None, - ccache_env=None, - compile_shared=False): +def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False): cmd = [compile_cmd] if compile_cmd != "nvcc": if compile_shared or output.endswith(".so") or output.endswith(".dylib"): @@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None): raise ValueError("ccache not found") try: - proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) (out, _) = proc.communicate() except FileNotFoundError: - raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)." - "Make sure it's installed" - " and the installation directory is in the %PATH% environment " - "variable. Prebuilt binaries can be found at: https://llvm.org/") \ - from None + raise RuntimeError( + "Can not find the LLVM clang for Windows clang.exe)." + "Make sure it's installed" + " and the installation directory is in the %PATH% environment " + "variable. Prebuilt binaries can be found at: https://llvm.org/" + ) from None if proc.returncode != 0: msg = "Compilation error:\n" msg += " ".join(cmd) + "\n" diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index 6772fe1..d80f0fd 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Wrapping functions to bridge frameworks with DLPack support to TVM""" + from tvm import runtime @@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): def adapt_tensor(arg): if isinstance(arg, tensor_type): - if arg.dtype in { - torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, - torch.float8_e5m2fnuz - }: - return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( - arg.shape, dtype=float8_dtype_map[arg.dtype]) + if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: + return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype]) return runtime.from_dlpack(to_dlpack_func(arg)) return arg diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py index 4e3c9a5..7b7f9f9 100644 --- a/tilelang/contrib/hipcc.py +++ b/tilelang/contrib/hipcc.py @@ -16,12 +16,7 @@ from tvm.base import py_str from tvm.contrib.rocm import get_rocm_arch, find_rocm_path -def compile_hip(code, - target_format="hsaco", - arch=None, - options=None, - path_target=None, - verbose=False): +def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False): """Compile HIP code with hipcc. Parameters @@ -61,7 +56,7 @@ def compile_hip(code, file_target = path_target if path_target else temp_target cmd = ["hipcc"] - cmd += ["-O3", '-c'] + cmd += ["-O3", "-c"] if isinstance(arch, str): cmd += [f"--offload-arch={arch}"] if target_format == "hsaco": diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 0e6a19b..36df6c8 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name # modified from apache tvm python/tvm/contrib/nvcc.py """Utility to invoke nvcc compiler in the system""" + from __future__ import annotations import os @@ -18,12 +19,7 @@ from tvm.base import py_str from tvm.contrib import utils -def compile_cuda(code, - target_format="ptx", - arch=None, - options=None, - path_target=None, - verbose=False): +def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False): """Compile cuda code with NVCC from env. Parameters @@ -67,7 +63,7 @@ def compile_cuda(code, temp_target = temp.relpath(f"{file_name}.{target_format}") pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() - kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None)) + kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None) if kernels_output_dir is not None: if not os.path.isdir(kernels_output_dir): os.makedirs(kernels_output_dir) @@ -114,10 +110,7 @@ def compile_cuda(code, print(py_str(out)) if proc.returncode != 0: - msg = f"{code}\n" \ - f"Compilation error:\n" \ - f"{py_str(out)}\n" \ - f"Command: {' '.join(cmd)}\n" + msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n" raise RuntimeError(msg) with open(file_target, "rb") as f: @@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries). if compile_flags: import shlex + for flag in compile_flags: # Split each string like a shell would, preserving quoted args tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)] @@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] return options -def get_ptx_from_source(code: str, - compile_flags: list[str] | None = None, - verbose: bool = False) -> str: +def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: """ Compile CUDA C++ source to PTX using NVCC and return as text. @@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None: return None -def get_sass_from_source(code: str, - compile_flags: list[str] | None = None, - verbose: bool = False) -> str: +def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: """ Compile CUDA C++ source to CUBIN and disassemble to SASS. @@ -246,9 +236,7 @@ def get_sass_from_source(code: str, cand_nvdisasm = _find_tool("nvdisasm") cand_cuobjdump = _find_tool("cuobjdump") if not cand_nvdisasm and not cand_cuobjdump: - raise RuntimeError( - "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH." - ) + raise RuntimeError("Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.") last_err: str | None = None try: # Attempt nvdisasm first @@ -268,8 +256,7 @@ def get_sass_from_source(code: str, return text last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}" # If we reach here, all attempts failed - raise RuntimeError(f"SASS disassembly failed. Tried tools: " - f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}") + raise RuntimeError(f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}") finally: with contextlib.suppress(Exception): os.remove(cubin_path) @@ -438,8 +425,7 @@ def get_target_compute_version(target=None): if tvm.cuda(0).exist: return tvm.cuda(0).compute_version - raise ValueError("No CUDA architecture was specified or GPU detected." - "Try specifying it by adding '-arch=sm_xx' to your target.") + raise ValueError("No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target.") def parse_compute_version(compute_version) -> tuple[int, int]: @@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None): warnings.warn( "Tensorcore will be disabled due to no CUDA architecture specified." "Try specifying it by adding '-arch=sm_xx' to your target.", - stacklevel=2) + stacklevel=2, + ) return False compute_version = target.attrs["arch"] # Compute version will be in the form "sm_{major}{minor}" diff --git a/tilelang/contrib/nvrtc.py b/tilelang/contrib/nvrtc.py index b691155..105c518 100644 --- a/tilelang/contrib/nvrtc.py +++ b/tilelang/contrib/nvrtc.py @@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]: return (major, minor) -def compile_cuda(code: str, - target_format: Literal["ptx", "cubin"] = "ptx", - arch: int | None = None, - options: str | list[str] | None = None, - verbose: bool = False) -> bytearray: +def compile_cuda( + code: str, + target_format: Literal["ptx", "cubin"] = "ptx", + arch: int | None = None, + options: str | list[str] | None = None, + verbose: bool = False, +) -> bytearray: """Compile cuda code with NVRTC. Parameters @@ -43,8 +45,7 @@ def compile_cuda(code: str, if arch is None: # If None, then it will use `tvm.target.Target.current().arch`. # Target arch could be a str like "80", "90", "90a", etc. - major, minor = parse_compute_version( - get_target_compute_version(Target.current(allow_none=True))) + major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True))) arch = major * 10 + minor prefix = "compute" if target_format == "ptx" else "sm" suffix = "a" if arch >= 90 else "" @@ -77,8 +78,7 @@ def compile_cuda(code: str, compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0] if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: - msg = f"{code}\n" \ - f"Compilation error:\n" + msg = f"{code}\nCompilation error:\n" if verbose: result, log_size = nvrtc.nvrtcGetProgramLogSize(program) assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}" @@ -105,7 +105,6 @@ def compile_cuda(code: str, assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}" # Destroy handler - assert nvrtc.nvrtcDestroyProgram( - program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}" + assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}" return result_bytes diff --git a/tilelang/contrib/rocm.py b/tilelang/contrib/rocm.py index 4a57c3c..f3b92e5 100644 --- a/tilelang/contrib/rocm.py +++ b/tilelang/contrib/rocm.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Utility for ROCm backend""" + # ruff: noqa import re import subprocess @@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"): gpu_arch = match.group(1) return gpu_arch except subprocess.CalledProcessError: - print(f"Unable to execute rocminfo command, \ + print( + f"Unable to execute rocminfo command, \ please ensure ROCm is installed and you have an AMD GPU on your system.\ - using default {gpu_arch}.") + using default {gpu_arch}." + ) return gpu_arch diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 7abdfb9..9932d52 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -1,4 +1,5 @@ """The compiler for TL programs.""" + from __future__ import annotations import os @@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target): def has_device_kernel_launch(attrs) -> bool: """Check if the attributes indicate a device kernel launch.""" - return bool(attrs and "calling_conv" in attrs and - attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) + return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) def is_device_call_c_device(func: tir.PrimFunc): attrs = func.attrs calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT) - is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC) + is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC # Check if it's a C target if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked: @@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: if var in func.buffer_map: tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) else: - if var.dtype == 'handle': + if var.dtype == "handle": raise ValueError( - f'Handle parameter {var} must be mapped to a buffer.\n' - f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.') + f"Handle parameter {var} must be mapped to a buffer.\n" + f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer." + ) tensor_types.append(KernelParam.from_var(var)) return tensor_types def canon_target_host(target: str | Target, target_host: str | Target | None): - if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" @@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> device_mod = tilelang.transform.LowerIntrin()(device_mod) device_mod = tir.transform.Simplify()(device_mod) if target.kind.name == "cuda": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")( - device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target) elif target.kind.name == "hip": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")( - device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) elif target.kind.name == "c": device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) elif target.kind.name == "llvm": @@ -222,12 +220,12 @@ def lower( enable_host_codegen=False, enable_device_compile=False, ) -> CompiledArtifact: - ''' - enable_host_codegen: whether to enable host codegen, default is False, as we have our - own host codegen implementation in jit. - enable_device_compile: whether to enable device codegen, default is False, as we have our - own device codegen implementation in jit. - ''' + """ + enable_host_codegen: whether to enable host codegen, default is False, as we have our + own host codegen implementation in jit. + enable_device_compile: whether to enable device codegen, default is False, as we have our + own device codegen implementation in jit. + """ mod = func_or_mod params = None @@ -259,14 +257,11 @@ def lower( host_mod = tir.transform.Filter(_is_host_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod) - codegen_mod = device_codegen( - device_mod, target) if enable_device_compile else device_codegen_without_compile( - device_mod, target) + codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target) if enable_host_codegen: host_mod = host_codegen(host_mod, target_host) host_mod.import_module(codegen_mod) - return CompiledArtifact( - host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) + return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py index de3c979..1abf66a 100644 --- a/tilelang/engine/param.py +++ b/tilelang/engine/param.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from dataclasses import dataclass @@ -14,6 +15,7 @@ class KernelParam: Represents parameters for a kernel operation, storing dtype and shape information. Used to describe tensor or scalar parameters in TVM/PyTorch interop. """ + dtype: torch.dtype # PyTorch data type of the parameter shape: list[int | Var] # List of dimensions, can be integers or TVM variables @@ -109,6 +111,7 @@ class CompiledArtifact: Represents a compiled kernel artifact containing both host and device code. Stores all necessary components for kernel execution in the TVM runtime. """ + host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index cd205a6..cef3d9a 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -6,8 +6,7 @@ from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, is_hopper -def allow_warp_specialized(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: # avoid circular import from tilelang.jit.adapter.utils import is_cuda_target @@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None, return not disable_warp_specialized -def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() if not have_tma(target): @@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> return enable_global_thread_sync -def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - enable_aggressive_merge = bool( - pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) + enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) if allow_warp_specialized(pass_ctx=pass_ctx, target=target): # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass # when warp specialization is enabled, as different warp threads may access different @@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: return ["txt", "png", "pdf", "svg"] if "," in formats_str: - formats_list = [f.strip() for f in formats_str.split(',')] + formats_list = [f.strip() for f in formats_str.split(",")] else: formats_list = [formats_str] @@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - mod = tilelang.transform.MergeSharedMemoryAllocations( - enable_aggressive_merge=enable_aggressive_merge)( - mod) + mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) # Inject PTX async copy must behind the thread sync pass diff --git a/tilelang/env.py b/tilelang/env.py index ce27aba..0583cd4 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -10,36 +10,34 @@ from dataclasses import dataclass logger = logging.getLogger(__name__) # SETUP ENVIRONMENT VARIABLES -CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") +CUTLASS_NOT_FOUND_MESSAGE = "CUTLASS is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = ( - "Composable Kernel is not installed or found in the expected path") +COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = "Composable Kernel is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") +TL_TEMPLATE_NOT_FOUND_MESSAGE = "TileLang is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") +TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path" TL_ROOT = os.path.dirname(os.path.abspath(__file__)) # Only expose the internal lib directory to sys.path to avoid shadowing # common top-level module names (e.g., utils, analysis) from user projects. -TL_LIBS = [os.path.join(TL_ROOT, 'lib')] +TL_LIBS = [os.path.join(TL_ROOT, "lib")] TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)] DEV = False -THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty') +THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty") if not os.path.exists(THIRD_PARTY_ROOT): DEV = True tl_dev_root = os.path.dirname(TL_ROOT) - dev_lib_root = os.path.join(tl_dev_root, 'build') + dev_lib_root = os.path.join(tl_dev_root, "build") # In dev builds, place artifacts under build/lib and point search path there # to avoid adding the entire build root to sys.path. - TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')] - THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty') - logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}') + TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")] + THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty") + logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}") -assert TL_LIBS and all( - os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}' +assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}" for lib in TL_LIBS: if lib not in sys.path: @@ -52,7 +50,7 @@ def _find_cuda_home() -> str: Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py """ # Guess #1 - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") if cuda_home is None: # Guess #2 nvcc_path = shutil.which("nvcc") @@ -70,15 +68,15 @@ def _find_cuda_home() -> str: else: # Guess #3 - if sys.platform == 'win32': - cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') - cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0] + if sys.platform == "win32": + cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") + cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0] else: # Linux/macOS - if os.path.exists('/usr/local/cuda'): - cuda_home = '/usr/local/cuda' - elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'): - cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64' + if os.path.exists("/usr/local/cuda"): + cuda_home = "/usr/local/cuda" + elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"): + cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64" # Validate found path if cuda_home is None or not os.path.exists(cuda_home): @@ -89,13 +87,13 @@ def _find_cuda_home() -> str: def _find_rocm_home() -> str: """Find the ROCM install path.""" - rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME') + rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME") if rocm_home is None: rocmcc_path = shutil.which("hipcc") if rocmcc_path is not None: rocm_home = os.path.dirname(os.path.dirname(rocmcc_path)) else: - rocm_home = '/opt/rocm' + rocm_home = "/opt/rocm" if not os.path.exists(rocm_home): rocm_home = None return rocm_home if rocm_home is not None else "" @@ -104,6 +102,7 @@ def _find_rocm_home() -> str: # Cache control class CacheState: """Class to manage global kernel caching state.""" + _enabled = True @classmethod @@ -230,13 +229,11 @@ class Environment: TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp")) # Kernel Build options - TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", - "1") # print kernel name on compile + TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile TILELANG_DISABLE_CACHE = EnvVar( - "TILELANG_DISABLE_CACHE", - "0") # disable kernel cache, usually for unit testing / debugging, high priority - TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", - "0") # DEPRECATED! clear cache automatically if set + "TILELANG_DISABLE_CACHE", "0" + ) # disable kernel cache, usually for unit testing / debugging, high priority + TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # DEPRECATED! clear cache automatically if set # Kernel selection options # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 @@ -244,12 +241,9 @@ class Environment: # Auto-tuning settings TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") - TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", - "0.9") # percent of CPUs used - TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", - "-1") # -1 means auto - TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", - "-1") # -1 means no limit + TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used + TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto + TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit # TVM integration SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0") @@ -323,18 +317,18 @@ def prepend_pythonpath(path): if env.TVM_IMPORT_PYTHON_PATH is not None: prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH) else: - tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python') + tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python") assert os.path.exists(tvm_path), tvm_path if tvm_path not in sys.path: prepend_pythonpath(tvm_path) env.TVM_IMPORT_PYTHON_PATH = tvm_path # By default, the built TVM-related libraries are stored in TL_LIBS. if os.environ.get("TVM_LIBRARY_PATH") is None: - os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) + os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) # Initialize CUTLASS paths if os.environ.get("TL_CUTLASS_PATH", None) is None: - cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include') + cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include") if os.path.exists(cutlass_inc_path): os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path else: @@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None: # Initialize COMPOSABLE_KERNEL paths if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: - ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include') + ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include") if os.path.exists(ck_inc_path): os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path else: diff --git a/tilelang/intrinsics/mfma_layout.py b/tilelang/intrinsics/mfma_layout.py index 183ba64..3895964 100644 --- a/tilelang/intrinsics/mfma_layout.py +++ b/tilelang/intrinsics/mfma_layout.py @@ -4,7 +4,7 @@ import tilelang.language as T def shared_16x4_to_local_64x1_layout_A(i, j): - thread_id = (j * 16 + i) + thread_id = j * 16 + i return thread_id, convert(0) @@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): def shared_4x16_to_local_64x1_layout_B(i, j): - thread_id = (i * 16 + j) + thread_id = i * 16 + j return thread_id, convert(0) @@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): def shared_16x16_to_local_64x4_layout_C(i, j): thread_id = j + (i // 4) * 16 - local = (i % 4) + local = i % 4 return thread_id, local @@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): def shared_16x16_to_local_64x4_layout_A(i, j): thread_id = i + 16 * (j // 4) - local = (j % 4) + local = j % 4 return thread_id, local @@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): def shared_16x16_to_local_64x4_layout_B(i, j): thread_id = j + (i // 4) * 16 - local = (i % 4) + local = i % 4 return thread_id, local @@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id): def shared_16x32_to_local_64x8_layout_A(i, j): thread_id = i + 16 * (j // 8) - local = (j % 8) + local = j % 8 return thread_id, local @@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id): def shared_16x32_to_local_64x8_layout_B(i, j): thread_id = j + (i // 8) * 16 - local = (i % 8) + local = i % 8 return thread_id, local @@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id): def shared_16x64_to_local_64x16_layout_A(i, j): thread_id = i + 16 * (j // 16) - local = (j % 16) + local = j % 16 return thread_id, local @@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id): def shared_16x64_to_local_64x16_layout_B(i, j): thread_id = i + 16 * (j // 16) - local = (j % 16) + local = j % 16 return thread_id, local diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 618a998..1e97bd0 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -6,7 +6,7 @@ from tvm import tir from tvm.ir import Range from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.runtime import convert -from .utils import (mfma_store_index_map) +from .utils import mfma_store_index_map from typing import Literal, Callable from tilelang.utils import is_fragment @@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter: self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k - self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.num_elems_per_byte = num_elems_per_byte self.thread_var = thread_var @@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter: def _initialize_mfma_prefix(self, k_dim=16): in_dtype, out_dtype = self.a_dtype, self.accum_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM - out_dtype_abbrv = { - "float16": "f16", - "float32": "f32", - "int8": "i8", - "int32": "i32" - }[out_dtype] + out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] in_dtype_abbrv = { "bfloat16": "bf16", @@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter: self.b_preshuffle = b_preshuffle def get_ldmatrix_index_map(self, is_b=False): - k_dim = self.k_dim * self.k_pack transposed = self.a_transposed if not is_b else self.b_transposed if k_dim == 4: @@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter: reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if is_b: index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B - reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + ) elif k_dim == 16: index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A - reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + ) if is_b: index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B - reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + ) elif k_dim == 32: index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A - reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + ) if is_b: index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B - reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + ) elif k_dim == 64: index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A - reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + ) if is_b: index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B - reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + ) else: raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") @@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter: else: return self.thread_var - def extract_thread_binding(self, - thread_id, - is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: - ''' - is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) - which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] - Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] - ''' + def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps @@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter: is_m_first = self.is_m_first if is_m_first: - lane_id, warp_n, warp_m = thread_id % WARP_SIZE, ( - thread_id // - WARP_SIZE) % block_col_warps, (thread_id // - (WARP_SIZE * block_col_warps)) % block_row_warps, + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) return lane_id, warp_n, warp_m else: - lane_id, warp_m, warp_n = thread_id % WARP_SIZE, ( - thread_id // - WARP_SIZE) % block_row_warps, (thread_id // - (WARP_SIZE * block_row_warps)) % block_col_warps, + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) return lane_id, warp_n, warp_m def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): @@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) - l, r = (rk * chunk + ki * (k_pack * micro_size_k), - warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, - A_base1 + r + col] + l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) - l, r = (warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, - A_base1 + r + col] + l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter: warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, - B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] else: for j in T.serial(warp_cols): @@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter: rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, - B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mfma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter: for local_id in T.vectorized(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * N_DIM + - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = C_local_buf[i * warp_cols * local_size_out + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * warp_cols * local_size_out + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter: for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) - C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, - (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + - col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] - - return _warp_stmatrix_global(C_local_buf, C_buf, - thread_binding) if is_global else _warp_stmatrix_shared( - C_local_buf, C_buf, thread_binding) - - def make_mfma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] + + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) + + def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MFMA results into a fragment buffer. @@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter: transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter: return local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r * - self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], + [micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) @@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") @@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter: class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): - def __init__( self, a_dtype: str = "float16", @@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): rk * (chunk // micro_size_k) + ki, warp_m * warp_rows + i, ) - A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, - col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] else: print(self.a_preshuffle) for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki) - A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, - col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] - return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, - rk) if is_global else _warp_ldmatrix_a_shared( - A_local_buf, A_buf, ki, thread_binding, rk) + return ( + _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk) + ) def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): warp_cols = self.warp_cols @@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, - col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] else: for j in T.serial(warp_cols): for local_id in T.vectorized(k_pack * local_size_b): @@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): rk * (chunk // micro_size_k) + ki, warp_n * warp_cols + j, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, - col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] - return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, - rk) if is_global else _warp_ldmatrix_b_shared( - B_local_buf, B_buf, ki, thread_binding, rk) + return ( + _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk) + ) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index f49b595..2eb575f 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id): """ - groupID = %laneid >> 2 - threadID_in_group = %laneid % 4 + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 - row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 - groupID + 8 Otherwise + row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 + groupID + 8 Otherwise - col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 - (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 + col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 + (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 """ row = (thread_id // 4) + 8 * (local_id % 4 // 2) col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4) @@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): """ - groupID = %laneid >> 2 - threadID_in_group = %laneid % 4 + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 - row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 - (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 + row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 + (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 - col = groupID + col = groupID """ col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8 row = (thread_id // 4) + 8 * (local_id // 4) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 5811eb5..28afdb2 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter: def get_store_index_map(self, inverse: bool = False) -> IndexMap: from .utils import mma_store_index_map, mma_store_index_map_fp64 + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out if DataType(self.accum_dtype).bits == 64: index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32") @@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter: inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter: ) return lane_id, warp_n, warp_m - def ldmatrix_a(self, - A_local_buf: Buffer, - A_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads if DataType(self.a_dtype).bits == 64: warp_row_tiles = self.warp_row_tiles @@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter: for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k - A_shared_buf_elem = A_buf[A_base0 + wk, - A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, - A_base1 + wk] + A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk] if ldmatrix_available: T.ptx_ldmatrix( @@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter: for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) if a_transposed: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, - A_base1 + wi + mi] + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] else: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, - A_base1 + wk + mk] + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) - def ldmatrix_b(self, - B_local_buf: Buffer, - B_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): - + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads if DataType(self.b_dtype).bits == 64: warp_col_tiles = self.warp_col_tiles @@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter: B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min B_stride_last = B_buf.shape[-1] - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) @@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter: ) if ldmatrix_available: - B_shared_buf_elem = B_buf[B_base0 + wi, - B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, - B_base1 + wi] + B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi] T.ptx_ldmatrix( b_dtype, @@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter: for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) if b_transposed: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, - B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] else: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, - B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter: accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv mma_prefix = self.mma_prefix - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 a_is_fragment = is_fragment(A_local_buf) b_is_fragment = is_fragment(B_local_buf) @@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter: B_local_buf.data, b_local_stride + j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + - lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), # saturate ) @@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter: local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * n_dim + - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter: C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, - ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] - return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) - if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") @@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter: from tilelang.utils import is_fragment shape = local_buf.shape - assert is_fragment( - local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}" + assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}" inverse_mma_store_layout = self.get_store_index_map(inverse=True) micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y @@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ".b16", A_local_buf.data, i * local_size_a, - T.address_of(A_shared_buf[ - warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * micro_size_k, - ]), + T.address_of( + A_shared_buf[ + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ] + ), get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), ) elif transform_kind_a == TransformKind.InterWarpTransform: @@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): warp_m * warp_rows + j, rk * (chunk // micro_size_k) + ki, ) - rii, rjj = (tx * local_size_a + - local_id) // micro_size_k, (tx * local_size_a + local_id) % ( - micro_size_k) - A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj]) + rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k) + A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj] else: raise ValueError("Unsupported TransformKind for Input A") @@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) - rii, rjj = (tx * local_size_dequantize + - local_id) // (micro_size_k // num_elems_per_byte), ( - tx * local_size_dequantize + local_id) % ( - micro_size_k // num_elems_per_byte) - B_local_buf[j * local_size_dequantize + local_id] = ( - B_shared_buf[ri, rj, rii, rjj]) + rii, rjj = ( + (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte), + (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte), + ) + B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj] else: raise ValueError("Unsupported TransformKind for Input B") @@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): - def mma(self, A_local_buf, B_local_buf, C_local_buf): warp_rows = self.warp_rows warp_cols = self.warp_cols @@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): - def mma(self, A_local_buf, B_local_buf, C_local_buf): - warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a diff --git a/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/intrinsics/mma_sm70_layout.py index e7a57da..8029234 100644 --- a/tilelang/intrinsics/mma_sm70_layout.py +++ b/tilelang/intrinsics/mma_sm70_layout.py @@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep): def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id): - row = (thread_id % 2) + ( - (local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 - col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % - 2) + (local_id // 4) * 8 + row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 + col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8 return row, col @@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id): def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id): - row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4)) + row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4)) col = local_id return row, col diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/intrinsics/mma_sm70_macro_generator.py index 7824808..3186adb 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/intrinsics/mma_sm70_macro_generator.py @@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter: def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out index_map = IndexMap.from_func( - mma_32x8_to_shared_16x16_layout_fp32 - if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, - index_dtype="int32") + mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, + index_dtype="int32", + ) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter: ) return lane_id, warp_n, warp_m - def ldmatrix_a(self, - A_local_buf: Buffer, - A_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows chunk = self.chunk @@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter: return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) - def ldmatrix_b(self, - B_local_buf: Buffer, - B_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter: for j in T.vectorized(local_size_b): if b_transposed: mi, mk = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, - B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] else: mk, mi = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, - B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk) - def mma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter: return _warp_mma(A_local_buf, B_local_buf, C_local_buf) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b( - i, j) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter: return lane_id, local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], - forward_fn=forward, - replicate=2) + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2 + ) warp_rows, warp_cols = self.warp_rows, self.warp_cols chunk = self.chunk @@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/intrinsics/mma_sp_layout.py index bae86bf..58034e7 100644 --- a/tilelang/intrinsics/mma_sp_layout.py +++ b/tilelang/intrinsics/mma_sp_layout.py @@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int: return (thread_id // 4) * 2 + (thread_id % 4) % 2 -def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_32bit(thread_id) row = logical_id // 4 + local_id * 8 col = logical_id % 4 return row, col -def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_32bit(thread_id) row = logical_id // 2 + local_id * 8 col = logical_id % 2 return row, col -def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, - local_id: int) -> tuple[int, int]: - return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit( - thread_id, local_id) # same mapping for 16bit and 32bit +def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit -def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, - local_id: int) -> tuple[int, int]: - return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit( - thread_id, local_id) # same mapping for 16bit and 32bit +def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit def get_logical_id_8bit(thread_id: int) -> int: return thread_id -def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) row = logical_id // 2 + local_id * 8 col = (logical_id % 4) // 2 * 4 + local_id return row, col -def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) row = logical_id // 2 + local_id * 8 col = (logical_id % 4) // 2 * 2 + local_id return row, col -def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: # local_id is always 0 logical_id = get_logical_id_8bit(thread_id) row = logical_id // 4 + (logical_id % 2) * 8 diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py index 629d95d..ea7aa89 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter: def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR - self.local_size_e = ( - m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] + self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter: inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter: for i in T.serial(warp_rows): # Assign A_shared_buf_elem - wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( - rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] if ldmatrix_available: @@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter: else: for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) - A_local_buf[i * local_size_a + - j] = A_shared_buf[wk + mk, wi + - mi] if a_transposed else A_shared_buf[wi + mi, - wk + mk] + A_local_buf[i * local_size_a + j] = ( + A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk] + ) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter: tx, _, warp_m = self.extract_thread_binding(thread_binding) for i in T.serial(warp_rows): # Assign E_shared_buf_elem - wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( - rk * warp_k + ki * micro_size_k) // self.e_factor + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor for j in T.serial(local_size_e): mi, mk = mma_load_layout(tx, j) - E_local_buf[i * local_size_e + - j] = E_shared_buf[wk + mk, - wi + mi] if trans else E_shared_buf[wi + mi, - wk + mk] + E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk] return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) @@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter: b_dtype = self.b_dtype b_transposed = self.b_transposed thread_binding = self.get_thread_binding() - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) @@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter: ) if ldmatrix_available: - B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, - wi] + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] if replicate_b: T.ptx_ldmatrix( @@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter: B_local_buf.data, i * local_size_b + lift(local_size_b) // 2, T.address_of(B_shared_buf_elem), - get_ldmatrix_offset_b("B", tx, - lift(local_size_b) // 2, stride, b_dtype, - b_transposed), + get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed), ) else: T.ptx_ldmatrix( @@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter: # must be transposed. for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + - j] = B_shared_buf[wi + mi, wk + - mk] if b_transposed else B_shared_buf[wk + mk, - wi + mi] + B_local_buf[i * local_size_b + j] = ( + B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi] + ) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mma_sp(self, - A_local_buf: Buffer, - E_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr = 0): + def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter: accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv mma_prefix = self.mma_prefix - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 a_is_fragment = is_fragment(A_local_buf) e_is_fragment = is_fragment(E_local_buf) @@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter: B_local_buf.data, b_local_stride + j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + - lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, E_local_buf.data, # metadata e_local_stride + i * local_size_e, # metadata offset self.SPARSE_SELECTOR, # sparse_selector @@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter: local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * n_dim + - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter: C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, - ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] - return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) - if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter: return local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order + [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] + if is_sr_axis_order else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, @@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 966f4dc..26208d6 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): is_m_first: bool = False, thread_var: Var | None = None, ): - super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, - block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, - num_elems_per_byte, is_m_first, thread_var) + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) def _assign_a_shared_layout(self, layout: Layout): self.a_shared_layout = layout @@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: raise ValueError(f"Unsupported swizzle mode: {layout}") - def tcgen05mma(self, - A_buf: Buffer, - B_buf: Buffer, - C_local_buf: Buffer, - mbar, - clear_accum: PrimExpr = False): - + def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False): if is_tensor_memory(A_buf): return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum) @@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): elems_in_bits = DataType(self.a_dtype).bits elems_in_bytes = elems_in_bits // 8 a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) if len(meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " - f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") + f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) # by default, we utilize non-swizzle layout offset - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * - elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * - elems_in_bytes) + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) if not a_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 @@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): for ki in T.unroll(0, (k_dim // micro_size_k)): scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) A_elem_offset = ( - ki % ak_atom_size - ) * micro_size_k + i * atom_m * a_swizzle_atom_elems + ( - ki // ak_atom_size - ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k + (ki % ak_atom_size) * micro_size_k + + i * atom_m * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) - B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + B_elem_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + j * atom_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) A_byte_offset = A_elem_offset * elems_in_bytes B_byte_offset = B_elem_offset * elems_in_bytes - C_offset = (i * n_dim + j * tmem_col_step - ) * accum_dtype_in_bits // 32 # 32 bits per tmem bank + C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 # 32 bits per tmem bank T.ptx_tcgen05_mma_ss( a_dtype_abbrv, @@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): """ assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)" if len(tmem_buf.shape) != 2: - raise ValueError( - f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}") + raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}") m = int(tmem_buf.shape[0]) n = int(tmem_buf.shape[1]) @@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): meta = self.get_tcgen5_mma_meta(m, n, k) if len(meta) != 5: - raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " - f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") + raise ValueError( + f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) atom_m, atom_n, _, _, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: - raise ValueError( - f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})" - ) + raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})") def forward(i: PrimExpr, j: PrimExpr): atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m) @@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): return Layout([m, n], forward) def get_tcgen5_mma_meta(self, m: int, n: int, k: int): - return _ffi_api.get_tcgen5_mma_meta( - int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) + return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) - def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, - b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr: + def get_tcgen5_instr_desc( + self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int + ) -> PrimExpr: desc = _ffi_api.get_tcgen5_instr_desc( atom_m, atom_n, diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 7fc9bab..fb24a4a 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -10,7 +10,7 @@ from .mma_layout import ( mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) -from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) +from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 51a90fb..483b6e7 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -15,9 +15,11 @@ from tilelang.layout import ( make_linear_layout, ) from tvm.runtime import convert -from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a, - shared_16x16_to_mma_32x8_layout_sr_a, - shared_16x32_to_mma_32x16_layout_sr_a) +from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) lift = convert @@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): is_m_first: bool | None = False, thread_var: Var | None = None, ): - super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, - block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, - num_elems_per_byte, is_m_first, thread_var) + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) self._initialize_wgmma_prefix(self.n_dim) def _assign_a_shared_layout(self, layout: Layout): @@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): def _initialize_wgmma_prefix(self, n_dim: int = 16): inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) assert inst_n % 8 == 0, ( - f"inst_n must be a multiple of 8, got {inst_n} " - f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") + f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 assert 8 <= inst_n <= 256, ( - f"inst_n must be within [8, 256], got {inst_n} " - f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") + f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) # 256 bits per instruction inst_k = 256 // DataType(self.a_dtype).bits self.wgmma_inst_m = inst_m @@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: raise ValueError(f"Unsupported swizzle mode: {layout}") - def wgmma(self, - A_region: BufferRegion, - B_region: BufferRegion, - C_region: BufferRegion, - clear_accum: PrimExpr = False, - wg_wait: int = 0): - + def wgmma( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): if is_fragment(A_region): return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait) @@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): elems_in_bytes = elems_in_bits // 8 a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes accum_bits = DataType(accum_dtype).bits accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 # by default, we utilize non-swizzle layout offset - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * - elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * - elems_in_bytes) + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) if not a_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 @@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): if a_m_axis_atoms <= 1: a_leading_byte_offset = 0 else: - a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * ( - a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) if a_m_axis_atoms <= 1: a_stride_byte_offset = 8 * elems_in_bytes * m_dim else: a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): desc_a = T.alloc_wgmma_desc() desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, - int(a_leading_byte_offset >> 4), - int(a_stride_byte_offset >> 4)) - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, - int(b_leading_byte_offset >> 4), - int(b_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_arrive() @@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): warp_i = (warp_m // 4) * num_inst_m + i warp_j = warp_n * num_inst_n + j A_offset = ( - ki % ak_atom_size - ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + ( - ki // ak_atom_size - ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k - B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + (ki % ak_atom_size) * micro_size_k + + warp_i * 64 * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit - T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, - a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, - (A_offset * elems_in_bytes) >> 4, desc_b.data, - (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset, - scale_out, scale_in_a, scale_in_b) + T.ptx_wgmma_ss( + accum_dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + desc_a.data, + (A_offset * elems_in_bytes) >> 4, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) T.warpgroup_commit_batch() if wg_wait >= 0: @@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): return _warp_mma(A_ptr, B_ptr, C_buf) - def wgmma_rs(self, - A_region: BufferRegion, - B_region: BufferRegion, - C_region: BufferRegion, - clear_accum: PrimExpr = False, - wg_wait: int = 0): + def wgmma_rs( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): local_size_a = self.local_size_a local_size_out = self.local_size_out a_dtype_abbrv = self.a_dtype_abbrv @@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): b_is_k_major = self.b_transposed b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes - - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, - int(b_leading_byte_offset >> 4), - int(b_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(A_buf, num_regs=a_regs) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_arrive() @@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): A_offset = ki * warp_rows * local_size_a + i * local_size_a B_offset = ( - ki // bk_atom_size - ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + ( - ki % bk_atom_size) * micro_size_k if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit T.ptx_wgmma_rs( accum_dtype, @@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A"], "matrix should be A for WGMMA" dtype = self.a_dtype dtype_bits = DataType(dtype).bits @@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): # the layout of mma.sync is row.col. # so the b matrix expected a transposed basic layout transform_func: Callable = None - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" @@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): replicate = block_col_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=False).replicate(replicate) - block_fragment = warp_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) else: # rs condition, transposed_a matrix - warp_fragment = base_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=False).replicate(replicate) - block_fragment = warp_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) return block_fragment diff --git a/tilelang/ir.py b/tilelang/ir.py index 08d4e96..b4a7de5 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -7,23 +7,19 @@ from tilelang import _ffi_api @tvm_ffi.register_object("tl.Fill") -class Fill(Node, Scriptable): - ... +class Fill(Node, Scriptable): ... @tvm_ffi.register_object("tl.AtomicAdd") -class AtomicAdd(Node, Scriptable): - ... +class AtomicAdd(Node, Scriptable): ... @tvm_ffi.register_object("tl.Copy") -class Copy(Node, Scriptable): - ... +class Copy(Node, Scriptable): ... @tvm_ffi.register_object("tl.Conv2DIm2Col") -class Conv2DIm2ColOp(Node, Scriptable): - ... +class Conv2DIm2ColOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.GemmWarpPolicy") @@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable): m_warp: int n_warp: int - def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, - is_wgmma: bool): - _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, - is_wgmma) + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool): + _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma) return self.m_warp, self.n_warp @@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable): m_warp: int n_warp: int - def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, - is_wgmma: bool, bits: int): - _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, - is_wgmma, bits) + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int): + _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits) return self.m_warp, self.n_warp @tvm_ffi.register_object("tl.Gemm") -class Gemm(Node, Scriptable): - ... +class Gemm(Node, Scriptable): ... @tvm_ffi.register_object("tl.GemmSP") -class GemmSP(Node, Scriptable): - ... +class GemmSP(Node, Scriptable): ... @tvm_ffi.register_object("tl.FinalizeReducerOp") -class FinalizeReducerOp(Node, Scriptable): - ... +class FinalizeReducerOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ParallelOp") -class ParallelOp(Node, Scriptable): - ... +class ParallelOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ReduceOp") -class ReduceOp(Node, Scriptable): - ... +class ReduceOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.CumSumOp") -class CumSumOp(Node, Scriptable): - ... +class CumSumOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.RegionOp") -class RegionOp(Node, Scriptable): - ... +class RegionOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ReduceType") -class ReduceType(Node, Scriptable): - ... +class ReduceType(Node, Scriptable): ... diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 09cbac5..9a5920d 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs. It includes functionality to JIT-compile TileLang programs into a runnable kernel adapter using TVM. """ + from __future__ import annotations from dataclasses import dataclass @@ -39,17 +40,16 @@ from tqdm.auto import tqdm logger = getLogger(__name__) -_P = ParamSpec('_P') -_KP = ParamSpec('_KP') -_T = TypeVar('_T') -_Ret = TypeVar('_Ret') +_P = ParamSpec("_P") +_KP = ParamSpec("_KP") +_T = TypeVar("_T") +_Ret = TypeVar("_Ret") def compile( func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -83,11 +83,9 @@ def compile( if isinstance(compile_flags, str): compile_flags = [compile_flags] - if hasattr(func, 'out_idx_override'): + if hasattr(func, "out_idx_override"): if func.out_idx_override is not None and out_idx is not None: - raise ValueError( - "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors" - ) + raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") out_idx = func.out_idx_override or out_idx # This path is not a performance critical path, so we can afford to convert the target. @@ -96,6 +94,7 @@ def compile( # Resolve execution backend (handles aliases, auto, validation per target) requested_backend = execution_backend from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + execution_backend = resolve_execution_backend(requested_backend, target) if verbose: allowed_now = allowed_backends_for_target(target, include_unavailable=False) @@ -119,17 +118,18 @@ def compile( ) -def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], - out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", - target: str | Target = "auto", - target_host: str | Target | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | str | None = None, - num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: +def par_compile( + funcs: Iterable[PrimFunc[_KP, _T]], + out_idx: list[int] | int | None = None, + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + target: str | Target = "auto", + target_host: str | Target | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, + num_workers: int = None, + ignore_error: bool = False, +) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. Parameters @@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options. """ - with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor: + with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor: futures = [] future_map = {} for i, func in enumerate(funcs): @@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], futures.append(future) results = [... for _ in futures] for future in tqdm( - concurrent.futures.as_completed(futures), - total=len(futures), - desc="Parallel Compiling", + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Parallel Compiling", ): idx = future_map[future] if ignore_error: @@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], @dataclass class JITImpl(Generic[_P, _KP, _T, _Ret]): - ''' + """ Detailed Just-In-Time wrapper for TileLang programs. This dataclass encapsulates the configuration and runtime helpers used by the @@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): PrimFunc and the resulting set is compiled in parallel via the module-level `par_compile` helper. Returns a list of JITKernel objects in the same order as the provided configs. - ''' + """ out_idx: list[int] | int | None execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] @@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}" return tir - def par_compile(self, - configs: Iterable[dict[str, Any] | tuple[str, Any]], - num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: + def par_compile( + self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False + ) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. Parameters @@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): """ configs = list(configs) funcs = [] - for cfg in tqdm(configs, desc='Elaborating'): + for cfg in tqdm(configs, desc="Elaborating"): if isinstance(cfg, tuple): funcs.append(self.get_tir(*cfg)) elif isinstance(cfg, dict): @@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): pass_configs=self.pass_configs, compile_flags=self.compile_flags, num_workers=num_workers, - ignore_error=ignore_error) + ignore_error=ignore_error, + ) def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: func = self.get_tir(*args, **kwargs) @@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): if self.debug_root_path: if isinstance(self.func, PrimFunc): - func_name = self.func.attrs['global_symbol'] + func_name = self.func.attrs["global_symbol"] else: - func_name = getattr(self.func, '__name__', 'jit_kernel') - kernel_file = f'tilelang_jit_kernel_{func_name}.c' - program_file = f'tilelang_jit_program_{func_name}.py' + func_name = getattr(self.func, "__name__", "jit_kernel") + kernel_file = f"tilelang_jit_kernel_{func_name}.c" + program_file = f"tilelang_jit_program_{func_name}.py" makedirs(self.debug_root_path, exist_ok=True) - with open(path.join(self.debug_root_path, kernel_file), 'w') as f: + with open(path.join(self.debug_root_path, kernel_file), "w") as f: print(kernel_result.get_kernel_source(), file=f) - with open(path.join(self.debug_root_path, program_file), 'w') as f: + with open(path.join(self.debug_root_path, program_file), "w") as f: print(func.script(), file=f) return kernel_result def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs): if isinstance(self.func, PrimFuncCreater): - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) return self.func.func_annot.parse_key(*args, **kwargs, **tune_params) else: - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) key_args_tuple = args key_kwargs_tuple = tuple(sorted(kwargs.items())) tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) @@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs): if isinstance(self.func, PrimFuncCreater): - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params) else: - raise NotImplementedError( - "convert_arg_to_kernel_args is only implemented for PrimFuncCreater.") + raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.") def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: # Separate out the tuning parameters from the user's kwargs # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache - return_compile_arguments = kwargs.pop('__return_compile_arguments', False) + return_compile_arguments = kwargs.pop("__return_compile_arguments", False) if return_compile_arguments: - logger.warning( - "`__return_compile_arguments` is deprecated and will be removed in future versions." - ) + logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.") compile_args = { - 'out_idx': self.out_idx, - 'execution_backend': self.execution_backend, - 'target': self.target, - 'target_host': self.target_host, - 'verbose': self.verbose, - 'pass_configs': self.pass_configs, - 'compile_flags': self.compile_flags, + "out_idx": self.out_idx, + "execution_backend": self.execution_backend, + "target": self.target, + "target_host": self.target_host, + "verbose": self.verbose, + "pass_configs": self.pass_configs, + "compile_flags": self.compile_flags, } return compile_args key = self.parse_cache_key(*args, **kwargs) - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) kernel = self._kernel_cache.get(key, None) if kernel is None: @@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr @overload -def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: - ... +def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ... @overload @@ -448,22 +444,22 @@ def jit( verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None -) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: - ... + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ... def jit( # This is the new public interface - func: Callable[_P, _T] | PrimFunc | None = None, - *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None): + func: Callable[_P, _T] | PrimFunc | None = None, + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +): """ Just-In-Time (JIT) compiler decorator for TileLang functions. @@ -516,7 +512,8 @@ def jit( # This is the new public interface compile_flags=compile_flags, func_source=inspect.getsource(orig_func), signature=inspect.signature(orig_func), - lazy_jit=False) + lazy_jit=False, + ) if func is not None: return decorator(func) @@ -525,8 +522,7 @@ def jit( # This is the new public interface @overload -def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: - ... +def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... @overload @@ -539,9 +535,8 @@ def lazy_jit( verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None -) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: - ... + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... def lazy_jit( @@ -555,7 +550,6 @@ def lazy_jit( debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, ): - if isinstance(compile_flags, str): compile_flags = [compile_flags] @@ -567,7 +561,8 @@ def lazy_jit( verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, - compile_flags=compile_flags) + compile_flags=compile_flags, + ) def decorator(func: Callable[_P, _T]): pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True) @@ -576,10 +571,7 @@ def lazy_jit( # return compile(pf, **compile_args) # else: return JITImpl( - func=pf, - **compile_args, - func_source=inspect.getsource(pf.orig_func), - signature=inspect.signature(pf.orig_func), - lazy_jit=True) + func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True + ) return decorator(func) if func is not None else decorator diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 6bd69cf..3669f9e 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from abc import ABC, abstractmethod @@ -8,7 +9,6 @@ import torch class BaseKernelAdapter(ABC): - func: Callable | None = None def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None: @@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC): result_idx = [] elif isinstance(result_idx, int): if result_idx > len(params) or result_idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}" - ) + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") if result_idx < 0: result_idx = len(params) + result_idx result_idx = [result_idx] elif isinstance(result_idx, list): for i, idx in enumerate(result_idx): if idx >= len(params) or idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}" - ) + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") if idx < 0: result_idx[i] = len(params) + idx else: diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py index e267730..92af826 100644 --- a/tilelang/jit/adapter/ctypes/adapter.py +++ b/tilelang/jit/adapter/ctypes/adapter.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations import torch from ..base import BaseKernelAdapter @@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter): param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes param_shapes: list[list] | None = None # Cache for parameter shapes - def __init__(self, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - host_kernel_source: str | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. Args: @@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): Converts PyTorch tensor pointers to C void pointers for ctypes interface. """ - ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args - ] + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): @property def is_dynamic(self): """Indicates whether the kernel handles dynamic shapes.""" - return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0) + return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0 def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index fe8fe5b..c456e4d 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations import ctypes import logging @@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter): # Pass configs for the compiler pass_configs: dict[str, Any] | None = None - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. Args: @@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter): self.lib.get_last_error.restype = ctypes.c_char_p result = self.lib.init() if result != 0: - error_msg = self.lib.get_last_error().decode('utf-8') + error_msg = self.lib.get_last_error().decode("utf-8") error_msg += f"\n{self.lib_code}" raise RuntimeError(f"Initialization failed: {error_msg}") @@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter): adapter.lib.get_last_error.restype = ctypes.c_char_p result = adapter.lib.init() if result != 0: - error_msg = adapter.lib.get_last_error().decode('utf-8') + error_msg = adapter.lib.get_last_error().decode("utf-8") raise RuntimeError(f"Initialization failed: {error_msg}") - adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, - adapter.lib) + adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib) adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map) adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map) adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map) @@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter): params = func.params ptr_map = {} for i, param in enumerate(params): - if param.dtype == 'handle': + if param.dtype == "handle": ptr_map[i] = param.name return ptr_map - def _process_static_buffer_infos(self) -> \ - tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], - dict[tir.Var, tuple[int, list[tuple[int, int]]]], - list[tuple[tir.Var]]]: + def _process_static_buffer_infos( + self, + ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]: """Extract information about static shapes from the TIR function. Maps buffer variables to their corresponding static shapes. @@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter): Converts PyTorch tensor pointers to C void pointers for ctypes interface. """ - ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args - ] + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter): skip_tensor_validation: Whether to skip tensor attributes validation which includes shape, dtype, device, etc. """ - return self.cython_wrapper.forward([*args], - stream=stream, - skip_tensor_validation=skip_tensor_validation) + return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation) return lambda_forward diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 208370b..d67f5b4 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -55,6 +55,7 @@ class LibraryGenerator: verbose = self.verbose if is_cuda_target(target): from tilelang.env import CUTLASS_INCLUDE_DIR + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 target_arch = get_target_arch(get_target_compute_version(target)) libpath = src.name.replace(".cu", ".so") @@ -65,15 +66,12 @@ class LibraryGenerator: "TL_ENABLE_FAST_MATH", "0.1.7", ) - enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, - True) + enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True) else: enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False) - ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, - None) - verbose_ptxas_output = self.pass_configs.get( - PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) + ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) command = [ get_nvcc_compiler(), @@ -102,6 +100,7 @@ class LibraryGenerator: elif is_hip_target(target): from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") rocm_path = find_rocm_path() @@ -119,6 +118,7 @@ class LibraryGenerator: ] elif is_cpu_target(target): from tilelang.contrib.cc import get_cplus_compiler + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") @@ -134,9 +134,7 @@ class LibraryGenerator: ] if self.compile_flags: - command += [ - item for flag in self.compile_flags for item in flag.split() if item not in command - ] + command += [item for flag in self.compile_flags for item in flag.split() if item not in command] command += ["-o", libpath] @@ -151,8 +149,7 @@ class LibraryGenerator: raise RuntimeError(f"Compile kernel failed because of {e}") from e if ret.returncode != 0: - raise RuntimeError(f"Compilation Failed! {command}" - f"\n {self.lib_code}") + raise RuntimeError(f"Compilation Failed! {command}\n {self.lib_code}") self.srcpath = src.name self.libpath = libpath diff --git a/tilelang/jit/adapter/nvrtc/__init__.py b/tilelang/jit/adapter/nvrtc/__init__.py index faa08c1..c8abe8d 100644 --- a/tilelang/jit/adapter/nvrtc/__init__.py +++ b/tilelang/jit/adapter/nvrtc/__init__.py @@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API. import logging -__all__ = [ - 'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available', - 'check_nvrtc_available' -] +__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"] logger = logging.getLogger(__name__) # Check if cuda-python is available is_nvrtc_available = False -NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. " - "Please install cuda-python via `pip install cuda-python` " - "if you want to use the NVRTC backend.") +NVRTC_UNAVAILABLE_MESSAGE = ( + "cuda-python is not available, NVRTC backend cannot be used. " + "Please install cuda-python via `pip install cuda-python` " + "if you want to use the NVRTC backend." +) try: import cuda.bindings.driver as cuda # noqa: F401 import cuda.bindings.nvrtc as nvrtc # noqa: F401 + is_nvrtc_available = True except ImportError as e: logger.debug(f"cuda-python import failed: {e}") diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index 4a465d3..d222f33 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter): pymodule = None kernels = {} - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): - + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): check_nvrtc_available() self.params = params @@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[KernelParam], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[KernelParam], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): return self.host_func def _forward_from_prebuild_lib(self, *args, stream: int | None = None): - """Low-level function to call the compiled CUDA kernel. - """ + """Low-level function to call the compiled CUDA kernel.""" return self.pymodule.call(self.kernels, *args, stream=stream) def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): diff --git a/tilelang/jit/adapter/nvrtc/libgen.py b/tilelang/jit/adapter/nvrtc/libgen.py index 50a587a..406cc44 100644 --- a/tilelang/jit/adapter/nvrtc/libgen.py +++ b/tilelang/jit/adapter/nvrtc/libgen.py @@ -13,6 +13,7 @@ Key responsibilities: - Load compiled cubin and extract kernel handles - Manage library lifecycle (load/unload) """ + from __future__ import annotations import importlib import logging @@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator): culib: CUDA library handle (CUlibrary) pymodule: Imported Python module containing call() function """ + host_func: str | None = None culib: cuda.CUlibrary | None = None pymodule: ModuleType | None = None @@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator): ctx = cuda.cuCtxGetCurrent()[1] if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS: import torch + torch.cuda.synchronize() - result, self.culib = cuda.cuLibraryLoadFromFile( - bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) + result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) if result != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}") @@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator): target = self.target verbose = self.verbose if is_cuda_target(target): - from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH) + from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) libpath = src.name.replace(".cu", ".cubin") @@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator): f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}", ] if self.compile_flags: - options += [ - item for flag in self.compile_flags for item in flag.split() - if item not in options - ] + options += [item for flag in self.compile_flags for item in flag.split() if item not in options] - cubin_bytes = compile_cuda( - self.lib_code, target_format="cubin", options=options, verbose=verbose) + cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose) with open(libpath, "wb") as f: f.write(cubin_bytes) @@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator): self.libpath = libpath self.pypath = src.name.replace(".cu", ".py") if self.host_func is None: - raise RuntimeError( - "Host function is not set, please call update_host_func() first.") + raise RuntimeError("Host function is not set, please call update_host_func() first.") with open(self.pypath, "w") as f: f.write(self.host_func) else: diff --git a/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/jit/adapter/nvrtc/wrapper.py index 7e00050..3df2b3b 100644 --- a/tilelang/jit/adapter/nvrtc/wrapper.py +++ b/tilelang/jit/adapter/nvrtc/wrapper.py @@ -12,6 +12,7 @@ Key design: - Dict-based deduplication ensures TMA descriptors created only once - Generates pure Python using cuda.bindings.driver for zero C++ dependency """ + from __future__ import annotations from typing import Any, ClassVar @@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit from tilelang import tvm as tvm from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper -from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr, - parse_function_call_args, parse_tma_descriptor_args) +from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args PREDEF_HOST_FUNC_PY = """ from cuda.bindings.driver import ( @@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): _generated_host_func: str | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): """Initialize NVRTC wrapper with compiled IR modules. Args: @@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.data.name, - "type": "ctypes.c_void_p", - }) + function_args.append( + { + "name": buffer.data.name, + "type": "ctypes.c_void_p", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: @@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): return (f"{name}.data_ptr()", arg_type) return (name, arg_type) - call_args = parse_function_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map, - transform_nvrtc_arg) + call_args = parse_function_call_args( + declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg + ) for arg_name, arg_type in call_args: if arg_type == "ctypes.c_void_p": @@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): break # Store kernel info for second pass - kernel_info_list.append({ - 'function_name': function_name, - 'block_info': block_info, - 'grid_info': grid_info, - 'dynamic_smem_buf': dynamic_smem_buf, - 'call_args': call_args, - 'device_index': device_index, - }) + kernel_info_list.append( + { + "function_name": function_name, + "block_info": block_info, + "grid_info": grid_info, + "dynamic_smem_buf": dynamic_smem_buf, + "call_args": call_args, + "device_index": device_index, + } + ) # Generate TMA descriptor initialization code once for all kernels kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) # Second pass: generate kernel launch code for each kernel for kernel_info in kernel_info_list: - function_name = kernel_info['function_name'] - block_info = kernel_info['block_info'] - grid_info = kernel_info['grid_info'] - dynamic_smem_buf = kernel_info['dynamic_smem_buf'] - call_args = kernel_info['call_args'] - device_index = kernel_info['device_index'] + function_name = kernel_info["function_name"] + block_info = kernel_info["block_info"] + grid_info = kernel_info["grid_info"] + dynamic_smem_buf = kernel_info["dynamic_smem_buf"] + call_args = kernel_info["call_args"] + device_index = kernel_info["device_index"] arg_names = ", ".join([arg[0] for arg in call_args]) arg_types = ", ".join([arg[1] for arg in call_args]) @@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): kernel_launch_code += init_l2_persistent_map # Generate kernel launch code - kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name, - self._pythonic_expr(grid_info[0]), - self._pythonic_expr(grid_info[1]), - self._pythonic_expr(grid_info[2]), - self._pythonic_expr(block_info[0]), - self._pythonic_expr(block_info[1]), - self._pythonic_expr(block_info[2]), - smem_str, arg_names, arg_types, - device_index) + kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format( + function_name, + self._pythonic_expr(grid_info[0]), + self._pythonic_expr(grid_info[1]), + self._pythonic_expr(grid_info[2]), + self._pythonic_expr(block_info[0]), + self._pythonic_expr(block_info[1]), + self._pythonic_expr(block_info[2]), + smem_str, + arg_names, + arg_types, + device_index, + ) # Reset L2 persistent map after all kernel execution if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY # Wrap the kernel dispatch logic in an external C function - host_func = PREDEF_HOST_FUNC_PY.format( - repr(list(function_informations.keys())), def_args, kernel_launch_code) + host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code) return host_func def generate_l2_persistent_map(self, function_name: str) -> str: @@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): if function_name not in self.l2_persistent_map: return "" init_l2_persistent_map = "" - for buffer_name, (hit_ratio, - size_in_bytes) in self.l2_persistent_map[function_name].items(): + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): # Get persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() try: num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) except TypeError: # as size_in_bytes may be a symbolic expression num_bytes = persisting_l2_cache_max_size - init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format( - buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) return init_l2_persistent_map - def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], - desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: """Generate Python code to initialize TMA descriptors. TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects @@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): return tma_descriptor_init # Parse TMA descriptor arguments using the common utility - parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, - desc_name_var_map, self._pythonic_expr) + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) # Generate Python code from parsed parameters for params in parsed_params: if not params.is_img2col: tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) else: tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), - ", ".join(map(lambda x: f"cuuint32_t({x})", - params.element_strides)), ", ".join(params.lower_corner), - ", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), + ", ".join(params.lower_corner), + ", ".join(params.upper_corner), + params.smem_box_channel, + params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) return tma_descriptor_init @@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): nonlocal function_params if isinstance(node, tvm.tir.Call): - if not (hasattr(node, "op") and - node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): return args = node.args if not args or args[0] != fn: return if len(args) < 1 + param_cnt: - raise AssertionError( - "tvm_call_packed should have at least 1 argument and match device function parameters" - ) - function_params = args[1:1 + param_cnt] + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] post_order_visit(self.host_func.body, visitor) assert function_params is not None, "function_params should not be None" diff --git a/tilelang/jit/adapter/torch/__init__.py b/tilelang/jit/adapter/torch/__init__.py index 2390e3e..f688993 100644 --- a/tilelang/jit/adapter/torch/__init__.py +++ b/tilelang/jit/adapter/torch/__init__.py @@ -1,3 +1,3 @@ from .metal import MetalKernelAdapter -__all__ = ['MetalKernelAdapter'] +__all__ = ["MetalKernelAdapter"] diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 0b1bc00..4690cf5 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam class MetalKernelAdapter(BaseKernelAdapter): - def __init__( self, params: list[KernelParam], @@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter): ): self.kernel_global_source = kernel_global_source if isinstance(func_or_mod, tir.PrimFunc): - func_name = func_or_mod.attrs['global_symbol'] + func_name = func_or_mod.attrs["global_symbol"] else: func_name = func_or_mod.__name__ - self.kernel_name = func_name + '_kernel' + self.kernel_name = func_name + "_kernel" self.verbose = verbose self.block_info = [1, 1, 1] @@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter): for var, func in device_mod.functions.items(): assert var.name_hint == self.kernel_name - thread_extent = func.attrs['thread_extent'] + thread_extent = func.attrs["thread_extent"] for tag, extent in thread_extent.items(): if "threadIdx" in tag: self.block_info["xyz".index(tag[-1])] = extent @@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter): self.grid_info["xyz".index(tag[-1])] = extent break else: - raise AssertionError(f'no kernel with name {func_name}') + raise AssertionError(f"no kernel with name {func_name}") # print(self.block_info, self.grid_info) super().__init__(func_or_mod, result_idx=result_idx, params=params) @@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter): _kernel = None def _convert_torch_func(self) -> Callable: - if self._kernel is None: - _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)] @wraps(_kernel) def launcher(*args: torch.Tensor): - return _kernel( *args, threads=_threads, diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 96b4c85..8b86864 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked, the execution observes the same stream context as the active Torch code. On non-CUDA builds, the stream/device fall back to 0/CPU semantics. """ + from __future__ import annotations from typing import Callable, Any @@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): - The stream pointer returned is a raw CUDA stream handle compatible with TVM's device API; on CPU or when CUDA is unavailable, we return 0. """ + # Class attributes to store compiled kernel information target: str | Target = "cuda" ir_module: tvm.IRModule | None = None @@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None # Stream/device functors are inherited from BaseKernelAdapter - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - rt_mod: tvm.runtime.Module | None = None, - host_kernel_source: str | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + rt_mod: tvm.runtime.Module | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. Args: @@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): # Validate input count strictly expected_inputs = len(self.params) - len(self.result_idx) if len(inputs) != expected_inputs: - raise ValueError( - f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.") + raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.") # Resolve the device used for outputs. Prefer the first tensor input's device # if available, otherwise use PyTorch's current device. @@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): for s in param_shapes[i]: if isinstance(s, tir.Var): for key in dynamic_symbolic_map: - if (str(s) == str(key)): - ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[ - key] + if str(s) == str(key): + ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key] if ref_id == 2: shape.append(inputs[ref_tensor_idx]) elif ref_id == 0: - shape.append( - tensor_list[ref_tensor_idx].shape[ref_shape_idx]) + shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) elif ref_id == 1: - shape.append( - tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) + shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) else: # Already converted to Python int during initialization shape.append(s) @@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): out_device = current_device_functor() if len(shape) == 0: - param_name = self.params[i].name if hasattr(self.params[i], - 'name') else f'parameter_{i}' + param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}" raise ValueError( f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. " - f"Expected shape: {shape}") + f"Expected shape: {shape}" + ) tensor = torch.empty(*shape, dtype=dtype, device=out_device) else: tensor = inputs[ins_idx] @@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): return func @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index 94e590d..15801ff 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -70,7 +70,6 @@ def get_annotated_mod( target_host: str | Target | None = None, model_type: Literal["device", "host", "all"] = "all", ) -> IRModule | tuple[IRModule, IRModule]: - # Validate model_type early if model_type not in {"device", "host", "all"}: raise ValueError(f"Invalid model type: {model_type}") @@ -95,21 +94,15 @@ def get_annotated_mod( # Define dispatch dictionary for different model types dispatch = { - "device": - lambda m: tir.transform.Filter(_is_device_call)(m), - "host": - lambda m: tir.transform.Filter(_is_host_call)(m), - "all": - lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call) - (m)), + "device": lambda m: tir.transform.Filter(_is_device_call)(m), + "host": lambda m: tir.transform.Filter(_is_host_call)(m), + "all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)), } return dispatch[model_type](mod) -def pythonic_expr(expr: tvm.tir.PrimExpr, - dtype_map: dict[str, str] | None = None, - ignore_cast: bool = False) -> str: +def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str: """ Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. @@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, s = f"({type_str}){value_str}" p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) elif isinstance( - node, - (tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT, - tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)): + node, + ( + tvm.tir.Mul, + tvm.tir.FloorDiv, + tvm.tir.Add, + tvm.tir.Sub, + tvm.tir.FloorMod, + tvm.tir.LT, + tvm.tir.LE, + tvm.tir.GT, + tvm.tir.GE, + tvm.tir.EQ, + tvm.tir.NE, + tvm.tir.And, + tvm.tir.Or, + ), + ): op_map = { tvm.tir.Mul: "*", tvm.tir.FloorDiv: "/", @@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, return next(iter(node_to_result_map[expr]), "") -def maybe_desc_name(name: str, - matches: list[str], - i: int, - desc_name_map: dict[str, str] | None = None) -> bool: +def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool: """ Check if a parameter name corresponds to a TMA descriptor. @@ -290,8 +294,7 @@ def parse_function_call_args( else: call_args.append(match) if desc_name_var_map is not None and function_params is not None: - assert len(call_args) <= len(function_params), \ - f"Too many arguments: {len(call_args)} > {len(function_params)}" + assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}" desc_name_var_map[match] = function_params[len(call_args) - 1] return call_args @@ -300,12 +303,7 @@ def parse_function_call_args( class TMADescriptorParams: """Parsed TMA descriptor parameters.""" - def __init__(self, - handle_name: str, - dtype: str, - tensor_rank: int, - global_address: Any, - is_img2col: bool = False): + def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False): self.handle_name = handle_name self.dtype = dtype self.tensor_rank = tensor_rank @@ -355,22 +353,19 @@ def parse_tma_descriptor_args( results = [] for handle_name, _ in desc_name_map.items(): - assert handle_name in desc_name_var_map, \ - f"Handle name {handle_name} not found in desc_name_var_map" + assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map" desc_var = desc_name_var_map[handle_name] - assert desc_var in tma_descriptor_args, \ - f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" + assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" args = tma_descriptor_args[desc_var] # Skip __tvm_tensormap_create_tiled and second element (like CUDA version) if len(args) < 3: - raise ValueError( - f"TMA descriptor args too short: {len(args)} elements, expected at least 3") + raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3") tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args - is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") + is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col" # Convert basic fields dtype = pythonic_expr_func(dtype) @@ -386,60 +381,45 @@ def parse_tma_descriptor_args( # Tiled mode expected_args_len = 4 * tensor_rank + 4 if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) # Extract dimensions and strides params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] - params.global_stride = [ - pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] - ] - params.box_dim = [ - pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] - ] - params.element_strides = [ - pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank] - ] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]] # Extract remaining parameters try: - interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 * - tensor_rank + 4] + interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4] params.interleave = pythonic_expr_func(interleave) params.swizzle = pythonic_expr_func(swizzle) params.l2_promotion = pythonic_expr_func(l2_promotion) params.oob_fill = pythonic_expr_func(oob_fill) except ValueError as e: - raise ValueError( - "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)" - ) from e + raise ValueError("Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e else: # Im2col mode expected_args_len = 5 * tensor_rank + 2 if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) # Extract dimensions and strides params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] - params.global_stride = [ - pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] - ] - params.element_strides = [ - pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] - ] - params.lower_corner = [ - pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2] - ] - params.upper_corner = [ - pythonic_expr_func(i) - for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4] - ] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]] + params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]] # Extract remaining parameters try: - smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \ - remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2] + smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[ + 5 * tensor_rank - 4 : 5 * tensor_rank + 2 + ] params.smem_box_pixel = pythonic_expr_func(smem_box_pixel) params.smem_box_channel = pythonic_expr_func(smem_box_channel) params.interleave = pythonic_expr_func(interleave) diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 7560797..c028a58 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -4,9 +4,18 @@ from tilelang import tvm as tvm from typing import Any from tvm import IRModule from tvm.target import Target -from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, - is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, - parse_function_call_args, parse_tma_descriptor_args) +from .utils import ( + is_metal_target, + match_declare_kernel, + match_declare_kernel_cpu, + is_cuda_target, + is_hip_target, + is_cpu_target, + get_annotated_mod, + pythonic_expr, + parse_function_call_args, + parse_tma_descriptor_args, +) import re import logging import textwrap @@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """ class BaseWrapper(ABC): - @abstractmethod def wrap(self, *args, **kwargs): raise NotImplementedError @@ -163,13 +171,15 @@ class TLCUDASourceWrapper: host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -211,15 +221,16 @@ class TLCUDASourceWrapper: for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.data.name, - "type": self._lookup_type(buffer.dtype) + "* __restrict__", - }) + function_args.append( + { + "name": buffer.data.name, + "type": self._lookup_type(buffer.dtype) + "* __restrict__", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: @@ -256,38 +267,40 @@ class TLCUDASourceWrapper: # Identify the start of the function body to insert arguments index = code.index("{", index) - block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" - grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" + block_str = ( + f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" + ) + grid_str = ( + f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" + ) smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf init_l2_persistent_map = self.generate_l2_persistent_map(function_name) kernel_launch_code += init_l2_persistent_map if self.use_cooperative_groups[function_name]: - args_list = parse_function_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map) - assert len(function_params) == len( - args_list - ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) args_array = [f"(void*)&{arg}" for arg in args_list] call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n" kernel_launch_code += call_args # Using cudaLaunchCooperativeKernel to launch the kernel kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( - function_name, grid_str, block_str, function_name + "_args", smem_str) + function_name, grid_str, block_str, function_name + "_args", smem_str + ) else: - args_list = parse_function_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map) - assert len(function_params) == len( - args_list - ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) call_args = ", ".join(args_list) kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n" - kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n" + kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n' if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE - init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, - desc_name_var_map) + init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) kernel_launch_code = init_tma_descriptor_args + kernel_launch_code # Wrap the kernel dispatch logic in an external C function @@ -298,46 +311,63 @@ class TLCUDASourceWrapper: if function_name not in self.l2_persistent_map: return "" init_l2_persistent_map = "" - for buffer_name, (hit_ratio, - size_in_bytes) in self.l2_persistent_map[function_name].items(): + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): # get persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() try: num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) except Exception: # as size_in_bytes maybe a symbolic expression num_bytes = persisting_l2_cache_max_size - init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format( - buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) return init_l2_persistent_map - def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], - desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: tma_descripter_init = "" if self.tma_descriptor_args is None: return tma_descripter_init # Parse TMA descriptor arguments using the common utility - parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, - desc_name_var_map, self._pythonic_expr) + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) # Generate C++ code from parsed parameters for params in parsed_params: if not params.is_img2col: tma_descripter_init += TMA_DESC_INIT_FUNC.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, - ",".join(params.global_dim), ",".join(params.global_stride), - ",".join(params.box_dim), ",".join(params.element_strides), params.interleave, - params.swizzle, params.l2_promotion, params.oob_fill) + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.box_dim), + ",".join(params.element_strides), + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) else: tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, - ",".join(params.global_dim), ",".join(params.global_stride), - ",".join(params.element_strides), ",".join(params.lower_corner), - ",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.element_strides), + ",".join(params.lower_corner), + ",".join(params.upper_corner), + params.smem_box_channel, + params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) return tma_descripter_init @@ -347,9 +377,8 @@ class TLCUDASourceWrapper: device_mod, host_mod = get_annotated_mod(self.mod, self.target) self.device_mod = device_mod self.host_mod = host_mod - assert (len(self.device_mod.functions) - >= 1), "Device module should have at least one function." - assert (len(self.host_mod.functions) == 1), "Only support one function in host module." + assert len(self.device_mod.functions) >= 1, "Device module should have at least one function." + assert len(self.host_mod.functions) == 1, "Only support one function in host module." block_info_map = {} grid_info_map = {} @@ -438,8 +467,7 @@ class TLCUDASourceWrapper: for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format( - function_name, dynamic_smem_buf) + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf) # Format the initialization function using the call_str init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs @@ -466,17 +494,14 @@ class TLCUDASourceWrapper: def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): nonlocal function_params if isinstance(node, tvm.tir.Call): - if not (hasattr(node, "op") and - node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): return args = node.args if not args or args[0] != fn: return if len(args) < 1 + param_cnt: - raise AssertionError( - "tvm_call_packed should have at least 1 argument and match device function parameters" - ) - function_params = args[1:1 + param_cnt] + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] post_order_visit(self.host_func.body, visitor) assert function_params is not None, "function_params should not be None" @@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): "uchar": "uint8_t", } - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) def get_init_func(self): @@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format( - function_name, dynamic_smem_buf) + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf) # Format the initialization function using the call_str init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs @@ -623,13 +649,15 @@ class TLCPUSourceWrapper: host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -658,15 +686,16 @@ class TLCPUSourceWrapper: for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": self._lookup_type(buffer.dtype) + "*", - }) + function_args.append( + { + "name": buffer.name, + "type": self._lookup_type(buffer.dtype) + "*", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) @@ -686,7 +715,6 @@ class TLCPUSourceWrapper: _call_str = """""" for function_name, _ in function_informations.items(): - # Find the location of the global kernel function in the code index = match_declare_kernel_cpu(code, function_name + "(") @@ -706,8 +734,8 @@ class TLCPUSourceWrapper: def parse_source_information(self): with tvm.transform.PassContext(opt_level=3, config=self.pass_configs): device_mod, host_mod = get_annotated_mod(self.mod, self.target) - assert (len(device_mod.functions) >= 1), "Device module should have at least one function." - assert (len(host_mod.functions) == 1), "Only support one function in host module." + assert len(device_mod.functions) >= 1, "Device module should have at least one function." + assert len(host_mod.functions) == 1, "Only support one function in host module." function_names = [] for g_var, _ in device_mod.functions.items(): @@ -767,14 +795,15 @@ class TLCPUSourceWrapper: class TLMetalSourceWrapper: - - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper): """ A wrapper class for the TileLang backend. """ + device_mod: IRModule | None = None host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None @@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper): target=self.target, device_mod=self.device_mod, host_mod=self.host_mod, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) return wrapper.lib_code class TLPyWrapper(TLWrapper): - def __init__(self, target: Target): super().__init__(target) @@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper): # assert self.scheduled_ir_module is not None, "Please assign optimized module first." if is_cuda_target(self.target): from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper + wrapper_class = TLNVRTCSourceWrapper else: raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") @@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper): target=self.target, device_mod=self.device_mod, host_mod=self.host_mod, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) return wrapper.host_func, wrapper.function_names diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py index fe60000..492e8cb 100644 --- a/tilelang/jit/execution_backend.py +++ b/tilelang/jit/execution_backend.py @@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T # Drop NVRTC if not importable try: from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy + if not is_nvrtc_available and "nvrtc" in allowed: allowed = [b for b in allowed if b != "nvrtc"] except Exception: @@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: if req not in allowed_all: raise ValueError( f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. " - f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.") + f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'." + ) # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed) if req not in allowed_avail: raise ValueError( f"Execution backend '{requested}' requires extra dependencies and is not available now. " - f"Try one of: {_format_options(allowed_avail)}.") + f"Try one of: {_format_options(allowed_avail)}." + ) return req diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 22cecf9..c05ef9e 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any, Callable, Generic, Literal, TypeVar + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec @@ -14,8 +15,7 @@ import tilelang from tilelang import tvm from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam -from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, - TVMFFIKernelAdapter, MetalKernelAdapter) +from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc @@ -24,8 +24,8 @@ import os logger = logging.getLogger(__name__) -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") class JITKernel(Generic[_P, _T]): @@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]): torch_function : Callable The compiled function that can be invoked as a PyTorch-compatible function. """ + prim_func: PrimFunc = None artifact: CompiledArtifact = None adapter: BaseKernelAdapter = None @@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]): if execution_backend == "cython": from tilelang.contrib.cc import get_cplus_compiler - assert ( - get_cplus_compiler() is not None - ), "Cython backend requires a C++ compiler, please install or use other backends." + assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends." if from_database: return @@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]): """ return self.torch_function(*args, **kwds) - def _compile_and_create_adapter(self, tilelang_func: PrimFunc, - out_idx: list[int]) -> BaseKernelAdapter: + def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter: """ Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. @@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]): target=target, target_host=target_host, enable_host_codegen=enable_host_codegen, - enable_device_compile=enable_device_compile) + enable_device_compile=enable_device_compile, + ) self.artifact = artifact @@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]): if execution_backend == "tvm_ffi": # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack. # But we need to ensure that the runtime is enabled and the runtime module is not None. - assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module." + assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module." adapter = TVMFFIKernelAdapter( params=artifact.params, result_idx=out_idx, @@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]): ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter + adapter = NVRTCKernelAdapter( params=artifact.params, result_idx=out_idx, @@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]): return adapter - def _create_adapter_from_database(self, - params: list[KernelParam], - result_idx: list[int] | int, - target: str | Target, - func_or_mod: PrimFunc | tvm.runtime.Module, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None) -> BaseKernelAdapter: + def _create_adapter_from_database( + self, + params: list[KernelParam], + result_idx: list[int] | int, + target: str | Target, + func_or_mod: PrimFunc | tvm.runtime.Module, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ) -> BaseKernelAdapter: target = self.target execution_backend = self.execution_backend @@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]): ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter + adapter = NVRTCKernelAdapter.from_database( params=params, result_idx=result_idx, @@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]): """ return cls(func=tilelang_func, **kwargs) - def get_profiler(self, - tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler: + def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler: """ Creates a profiler to benchmark the compiled runtime module. @@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]): Profiler A Profiler instance for benchmarking the runtime module. """ - return Profiler(self.params, self.out_idx, - tensor_supply_type).with_default_adapter(self.adapter) + return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter) def get_kernel_source(self, kernel_only: bool = True) -> str: """ @@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]): dir_path = os.path.dirname(kernel_path) if dir_path: os.makedirs(dir_path, exist_ok=True) - with open(kernel_path, 'w') as f: + with open(kernel_path, "w") as f: f.write(self.get_kernel_source()) if host_path is not None: dir_path = os.path.dirname(host_path) if dir_path: os.makedirs(dir_path, exist_ok=True) - with open(host_path, 'w') as f: + with open(host_path, "w") as f: f.write(self.get_host_source()) except Exception as e: logger.error(f"Failed to export sources: {e}") # Backward compatibility alias (deprecated) - def print_source_code(self, - which: Literal["kernel", "host", "both"] = "kernel", - file: str | None = None) -> None: + def print_source_code(self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None: """ Deprecated: use show_source() or export_sources() instead. @@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]): >>> # Old API (still works but deprecated) >>> jit_kernel.print_source_code(file="/tmp/kernel.cu") """ - logger.warning( - "print_source_code is deprecated; use show_source() or export_sources() instead.") + logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.") if file is not None: # Historical behavior wrote only kernel source when file provided self.export_sources(kernel_path=file) else: self.show_source(which=which) - def update_tuner_result(self, latency: float, config: dict[str, Any], - ref_latency: float) -> JITKernel: + def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel: """ Updates the tuning results for this kernel. @@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]): verbose = self.verbose # Ensure target is set so nvcc picks correct arch via Target.current() with self.target: - return tl_nvcc.get_ptx_from_source( - code, compile_flags=self.compile_flags, verbose=verbose) + return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose) def show_ptx(self) -> None: """ @@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]): if verbose is None: verbose = self.verbose with self.target: - return tl_nvcc.get_sass_from_source( - code, compile_flags=self.compile_flags, verbose=verbose) + return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose) def show_sass(self) -> None: """ diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index c91ac3c..0f3d5fb 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations # from .parser import * @@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401 from .symbolics import dynamic, symbolic # noqa: F401 from .annotations import ( # noqa: F401 - use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio, + use_swizzle, + annotate_layout, + annotate_safe_value, + annotate_l2_hit_ratio, ) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 7337782..b26f0b8 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -13,8 +13,10 @@ Available allocation functions: Each function takes shape and dtype parameters and returns a TVM buffer object with the appropriate memory scope. """ + from __future__ import annotations from typing import TypeVar, overload, Literal, Callable + # Python 3.9 compatibility for advanced typing features (PEP 646) try: from typing import TypeVarTuple, Unpack # type: ignore[attr-defined] @@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype from .v2.builder import OutTensor from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer -_Shapes = TypeVarTuple('_Shapes') -_DType = TypeVar('_DType') +_Shapes = TypeVarTuple("_Shapes") +_DType = TypeVar("_DType") -def alloc_shared(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a shared memory buffer for inter-thread communication. Args: @@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]], return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_local(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_local(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a local memory buffer for thread-private storage. Args: @@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]], return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_fragment(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_fragment( + shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment" +) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a fragment memory buffer for specialized operations. Args: @@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]], @overload -def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer: - ... +def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ... @overload -def alloc_var(dtype: str, - scope: str = 'local.var', - *, - init: PrimExpr | int | float | None = None) -> Buffer: - ... +def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ... def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): @@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): raise TypeError("Scope must be provided as a string in alloc_var.") parsed_scope = parsed_scope_arg elif len(args) > 2: - raise TypeError( - f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.") + raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.") if not isinstance(parsed_scope, str): raise TypeError("Scope must be a string in alloc_var.") @@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"): @overload -def empty(shape: tuple[Unpack[_Shapes]], - dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... +def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... -def empty(*shape: Unpack[_Shapes], - dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: +def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: if len(shape) == 1 and isinstance(shape[0], (tuple, list)): return OutTensor(shape[0], dtype) elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): @@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes], elif all([isinstance(x, (int, PrimExpr)) for x in shape]): return OutTensor(shape, dtype) else: - raise RuntimeError(f'Invalid shape {shape}') + raise RuntimeError(f"Invalid shape {shape}") diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py index 2ce71cb..09cfa58 100644 --- a/tilelang/language/annotations.py +++ b/tilelang/language/annotations.py @@ -1,4 +1,5 @@ """Annotation helpers exposed on the TileLang language surface.""" + from typing import Callable from tilelang.layout import Layout diff --git a/tilelang/language/ast/__init__.py b/tilelang/language/ast/__init__.py index 9d77454..6ab6249 100644 --- a/tilelang/language/ast/__init__.py +++ b/tilelang/language/ast/__init__.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """Package tvm.script.ir_builder.tir""" + from .ir import * # noqa: F401 from .ir import boolean as bool # noqa: F401 from .ir import buffer as Buffer # noqa: F401 diff --git a/tilelang/language/ast/_ffi_api.py b/tilelang/language/ast/_ffi_api.py index 518d57e..5cc7476 100644 --- a/tilelang/language/ast/_ffi_api.py +++ b/tilelang/language/ast/_ffi_api.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """FFI APIs""" + import tvm.ffi tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 41b658d..0352514 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def reduce( @@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def scan( @@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def opaque( @@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: @@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name The iteration variables. """ iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member - kinds, bindings, dtype) + kinds, bindings, dtype + ) return iter_vars[0] if len(iter_vars) == 1 else iter_vars S = spatial # pylint: disable=invalid-name R = reduce # pylint: disable=invalid-name -def serial(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The serial For statement. Parameters @@ -700,10 +702,7 @@ def serial(start: PrimExpr, return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def parallel(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The parallel For statement. Parameters @@ -731,10 +730,7 @@ def parallel(start: PrimExpr, return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def vectorized(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The vectorized For statement. Parameters @@ -762,10 +758,7 @@ def vectorized(start: PrimExpr, return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def unroll(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The unrolled For statement. Parameters @@ -837,7 +830,8 @@ def thread_binding( else: start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member - start, stop, thread, annotations) + start, stop, thread, annotations + ) def grid(*extents: PrimExpr) -> frame.ForFrame: @@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d def LetStmt( # pylint: disable=invalid-name - value: PrimExpr, - type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name - *, - var: Optional[Var] = None, # pylint: disable=redefined-outer-name + value: PrimExpr, + type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name + *, + var: Optional[Var] = None, # pylint: disable=redefined-outer-name ) -> frame.LetFrame: """Create a LetStmt binding @@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name def Let( # pylint: disable=invalid-name - expr: PrimExpr, - where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name + expr: PrimExpr, + where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name ) -> PrimExpr: """Create a Let expression binding""" assert len(where) == 1, "T.Let only allows `where` to have exactly one element" @@ -980,7 +974,8 @@ def realize( The result RealizeFrame. """ return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member - buffer_slice, storage_scope, condition) + buffer_slice, storage_scope, condition + ) def allocate( @@ -1012,7 +1007,8 @@ def allocate( if isinstance(condition, bool): condition = IntImm("bool", condition) return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member - extents, dtype, scope, condition, annotations) + extents, dtype, scope, condition, annotations + ) def allocate_const( @@ -1048,7 +1044,8 @@ def allocate_const( np_data = np_data.reshape(extents) return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member - ndarray.array(np_data), dtype, extents, annotations) + ndarray.array(np_data), dtype, extents, annotations + ) def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: @@ -1297,7 +1294,8 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices) + buffer, value, expr_indices + ) def prefetch( @@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: Optional[str] = None, - storage_scope: str = "global", - *, - is_size_var: bool = False) -> Var: +def handle(dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: res = combiner(*args) if not isinstance(res, tuple): res = (res,) - return CommReducer(args[:num_args // 2], args[num_args // 2:], res, identity) + return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) def index_map( @@ -1700,16 +1695,15 @@ def target( The target. """ if not isinstance(target_config, (str, dict)): - raise ValueError( - f"T.target expected a config dict or string, but got {type(target_config)}") + raise ValueError(f"T.target expected a config dict or string, but got {type(target_config)}") if host is not None and not isinstance(host, (str, dict, Target)): - raise ValueError("T.target expected the host to be " - "a config dict, string, or T.target, " - f"but got {type(host)}") + raise ValueError(f"T.target expected the host to be a config dict, string, or T.target, but got {type(host)}") if isinstance(target_config, dict) and "host" in target_config and host is not None: - raise ValueError("T.target expects to either receive the host " - "as part of the target's config dictionary, " - "or as a separate argument, but not both.") + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." + ) return Target(target_config, host) @@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name self.value = value def __iter__(self): - def f(): for i in self.value: yield meta_var(i) @@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name def _op_wrapper(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: @@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale) def _dtype_forward(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 07e45bb..89a3af2 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -1,6 +1,7 @@ # Copyright (c) Tile-AI Corporation. # Licensed under the MIT License. """Atomic operations for tilelang.""" + from __future__ import annotations import tilelang.language as T @@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = { } -def atomic_max(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False) -> PrimExpr: +def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Perform an atomic maximum on the value stored at dst with an optional memory-order. @@ -64,10 +62,7 @@ def atomic_max(dst: Buffer, return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) -def atomic_min(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False) -> PrimExpr: +def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Atomically update the value at dst to the minimum of its current value and value. @@ -112,11 +107,7 @@ def atomic_min(dst: Buffer, return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) -def atomic_add(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False, - use_tma: bool = False) -> PrimExpr: +def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr: """ Atomically add `value` into `dst`, returning a handle to the operation. @@ -191,8 +182,7 @@ def atomic_add(dst: Buffer, if memory_order is None: return T.call_extern(return_type, func_name, dst, value) else: - return T.call_extern(return_type, func_name, dst, value, - _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) if isinstance(dst, Buffer) and isinstance(value, Buffer): ir.assert_structural_equal(dst.shape, value.shape) @@ -208,14 +198,12 @@ def atomic_add(dst: Buffer, # Note: tile-region-based atomic operations don't support return_prev yet # This would need to be implemented in the tile runtime if return_prev: - raise NotImplementedError( - "return_prev is not supported for tile-region-based atomic operations") + raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") if memory_order is None: return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) else: - return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, - _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 0bc12fc..60739e6 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang import tvm as tvm @@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int): def inc_max_nreg(reg_count: int): - """Increment the maximum number of registers to use. - """ + """Increment the maximum number of registers to use.""" return set_max_nreg(reg_count, 1) def dec_max_nreg(reg_count: int): - """Decrement the maximum number of registers to use. - """ + """Decrement the maximum number of registers to use.""" return set_max_nreg(reg_count, 0) def annotate_producer_reg_dealloc(reg_count: int = 24): - """Annotate the producer reg dealloc. - """ + """Annotate the producer reg dealloc.""" return dec_max_nreg(reg_count) def annotate_consumer_reg_alloc(reg_count: int = 240): - """Annotate the consumer reg alloc. - """ + """Annotate the consumer reg alloc.""" return inc_max_nreg(reg_count) def no_set_max_nreg(): - """Disable the maximum register limit setting. - """ + """Disable the maximum register limit setting.""" return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg")) def disable_warp_group_reg_alloc(): - """Disable the warp group reg alloc. - """ + """Disable the warp group reg alloc.""" return no_set_max_nreg() @@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int): return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) -def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_lane_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the logical lane index of the calling thread within a warp. Parameters @@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) -def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_warp_idx_sync( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the canonical warp index, assuming the warp's threads are converged. Parameters @@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) -def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_warp_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the canonical warp index without synchronizing the warp. Parameters @@ -429,8 +430,7 @@ def get_warp_group_idx( args.append(warp_size_expr) if warps_per_group_expr is not None: if warp_size_expr is None: - raise ValueError("get_warp_group_idx expects `warp_size` when specifying " - "`warps_per_group`.") + raise ValueError("get_warp_group_idx expects `warp_size` when specifying `warps_per_group`.") args.append(warps_per_group_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) @@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) -def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, - offset: int | PrimExpr = 0, - num_regs: int | PrimExpr | None = None, - dtype: str | None = None): +def warpgroup_fence_operand( + buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None +): """Insert a warpgroup fence for the destination accumulator registers. This prevents NVCC from sinking uses of accumulator fragments past the corresponding @@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) if isinstance(buffer_or_ptr, tir.Buffer): data_ptr = buffer_or_ptr.data @@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, if isinstance(dim, tir.IntImm): total_elems *= int(dim) else: - raise ValueError( - "warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") + raise ValueError("warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") bits_per_elem = DataType(dtype).bits num_regs = (total_elems * bits_per_elem + 31) // 32 elif isinstance(buffer_or_ptr, BufferRegion): @@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, bits_per_elem = DataType(dtype).bits num_regs = (total_elems * bits_per_elem + 31) // 32 else: - raise ValueError( - "warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic." - ) + raise ValueError("warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic.") return evaluate( tir.call_intrin( "handle", @@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) else: data_ptr = buffer_or_ptr # Try to infer dtype from common pointer expressions when not provided @@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, except Exception: inferred = None if inferred is None: - raise ValueError( - "dtype must be provided when passing a pointer expression and cannot be inferred." - ) + raise ValueError("dtype must be provided when passing a pointer expression and cannot be inferred.") dtype = inferred if num_regs is None: raise ValueError("num_regs must be provided when passing a pointer expression.") @@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) def wait_wgmma(id: int): @@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_xor", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset) def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): @@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_down", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset) def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): @@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call) if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_up", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset) def sync_threads(barrier_id: int = None, arrive_count: int = None): - """Synchronize all threads in a block. - """ + """Synchronize all threads in a block.""" args = [] if barrier_id is not None: args.append(barrier_id) @@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None): def sync_global(): - """Synchronize all threads in the entire grid. - """ + """Synchronize all threads in the entire grid.""" tx, ty, tz = get_thread_bindings() ex, ey, ez = get_block_extents() print(tx, ty, tz, ex, ey, ez) @@ -724,8 +719,7 @@ def sync_global(): def sync_grid(): - """Synchronize all threads in a grid. - """ + """Synchronize all threads in a grid.""" return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) @@ -741,12 +735,10 @@ def initialize_wgmma_descriptor( if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or - descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) return evaluate( tir.call_intrin( @@ -757,7 +749,8 @@ def initialize_wgmma_descriptor( layout_type_, int(leading_byte_offset), int(stride_byte_offset), - )) + ) + ) def initialize_tcgen05_descriptor( @@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor( if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or - descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) return evaluate( tir.call_intrin( @@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor( int(base_offset), tir.IntImm("int32", 1 if leading_is_absolute else 0), int(swizzle_mode), - )) + ) + ) def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: @@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and len( - descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, tir.Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) - return evaluate( - tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, - offset)) + return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, offset)) def loop_break(): - """Break out of the innermost loop. - """ + """Break out of the innermost loop.""" return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): - """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. - """ + """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.""" return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index cabc4a3..b80a24e 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Literal from tilelang import language as T @@ -10,11 +11,13 @@ from tilelang.utils.language import ( from tvm import ir, tir -def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, - dst: tir.Buffer | tir.BufferLoad, - coalesced_width: int | None = None, - disable_tma: bool = False, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): +def copy( + src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Copy data between memory regions. Args: @@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, src_extent = get_extent(src) dst_extent = get_extent(dst) # Combine the nested if statements into a single if statement as suggested by SIM102 - if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and - isinstance(dst, tir.BufferLoad)): + if src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad): # check if the case is like this: # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] @@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, - disable_tma, eviction_policy) - - -def c2d_im2col(img: tir.Buffer, - col: tir.Buffer, - nhw_step: tir.PrimExpr, - c_step: tir.PrimExpr, - kernel: int, - stride: int, - dilation: int, - pad: int, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) + + +def c2d_im2col( + img: tir.Buffer, + col: tir.Buffer, + nhw_step: tir.PrimExpr, + c_step: tir.PrimExpr, + kernel: int, + stride: int, + dilation: int, + pad: int, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Perform im2col transformation for 2D convolution. Args: @@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer, eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] img_region = to_buffer_region(img, access_type="r") col_region = to_buffer_region(col, access_type="w") - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region, - nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.c2d_im2col"), + img_region, + col_region, + nhw_step, + c_step, + kernel, + stride, + dilation, + pad, + eviction_policy, + ) diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 720c9e9..e2f4b1c 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -1,8 +1,9 @@ """The language interface for tl programs.""" + from __future__ import annotations import tilelang.language as T from tvm.tir import PrimExpr, Buffer, op -from tilelang.utils.language import (bits_product, prim_expr_equal) +from tilelang.utils.language import bits_product, prim_expr_equal from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 @@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: Returns: Buffer: A new buffer view with the specified shape """ - assert prim_expr_equal( - bits_product(shape, src.dtype), bits_product(src.shape, src.dtype) - ), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" + assert prim_expr_equal(bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)), ( + f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" + ) return T.Tensor(shape, src.dtype, src.data) @@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N shape = src.shape if dtype is None: dtype = src.dtype - assert prim_expr_equal(bits_product(shape, dtype), - bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." + assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." return T.Tensor(shape, dtype, src.data) diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index 4a20f3f..5adac92 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T @@ -11,7 +12,8 @@ from tilelang.utils.language import ( prim_expr_equal, ) from tilelang.language.utils import ( - buffer_region_to_tile_region,) + buffer_region_to_tile_region, +) def gemm_sp( @@ -169,18 +171,19 @@ def gemm_sp_v2( assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" if len(A_shape) > 2: for i in range(len(A_shape) - 2): - assert A_shape[i] == 1, \ + assert A_shape[i] == 1, ( "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) if len(B_shape) > 2: for i in range(len(B_shape) - 2): - assert B_shape[i] == 1, \ + assert B_shape[i] == 1, ( "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) M, N = C_shape K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) K_B = B_shape[-1] if transpose_B else B_shape[-2] - assert prim_expr_equal( - K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" + assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" stride_a = A_stride[-2] stride_b = B_stride[-2] diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index b237333..af301c2 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tvm import tir from tilelang.language import has_let_value, get_let_value @@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim extents = [tir.IntImm("int32", 1) for _ in buffer.indices] else: extents = [] - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), - to_buffer_region(buffer, access_type="w", extents=extents), value) + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value) def clear(buffer: tir.Buffer | tir.Var): @@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var): elif isinstance(buffer_region, tir.BufferLoad): region = get_buffer_region_from_load(buffer_region) if region is None: - raise ValueError( - f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") + raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") return fill(region, 0) else: raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") diff --git a/tilelang/language/frame.py b/tilelang/language/frame.py index db64995..7e60f46 100644 --- a/tilelang/language/frame.py +++ b/tilelang/language/frame.py @@ -1,4 +1,5 @@ """Override the LetFrame to print a message when entering the frame.""" + from __future__ import annotations from tvm.ffi import register_object as _register_object from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion @@ -29,7 +30,7 @@ class FrameStack: item: The frame object to push onto the stack """ self._stack.append(item) - if hasattr(item, 'var') and hasattr(item, 'value'): + if hasattr(item, "var") and hasattr(item, "value"): self._var_value_map[item.var] = item.value def pop(self): @@ -43,7 +44,7 @@ class FrameStack: """ if self._stack: item = self._stack.pop() - if hasattr(item, 'var'): + if hasattr(item, "var"): self._var_value_map.pop(item.var, None) return item raise IndexError(f"{self.__class__.__name__} is empty") @@ -129,8 +130,7 @@ class LetFrame(TIRFrame): is_block_load = True break if is_block_load: - self.value = BufferRegion(self.value.buffer, - [Range(x.base, x.lanes) for x in indices]) + self.value = BufferRegion(self.value.buffer, [Range(x.base, x.lanes) for x in indices]) _get_let_stack().push(self) return self.var diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index db8e04a..56f6805 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T @@ -11,7 +12,8 @@ from tilelang.utils.language import ( prim_expr_equal, ) from tilelang.language.utils import ( - buffer_region_to_tile_region,) + buffer_region_to_tile_region, +) from tilelang.env import env as _env @@ -68,12 +70,14 @@ def _gemm_impl( assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" if len(A_shape) > 2: for i in range(len(A_shape) - 2): - assert A_shape[i] == 1, \ + assert A_shape[i] == 1, ( "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) if len(B_shape) > 2: for i in range(len(B_shape) - 2): - assert B_shape[i] == 1, \ + assert B_shape[i] == 1, ( "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) M, N = C_shape K = A_shape[-2] if transpose_A else A_shape[-1] @@ -96,9 +100,29 @@ def _gemm_impl( A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) - return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A, - transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a, - offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1]) + return tir.call_intrin( + "handle", + tir.op.Op.get(op_key), + A_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + mbar, + C_coords[0], + C_coords[1], + ) # Public wrappers diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 5e819da..625531b 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from collections import deque from tvm import tir @@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame): _get_current_stack().push(self) last_block_frame = self.frames[-1] - assert isinstance(last_block_frame, - BlockFrame), f"Last frame must be a block frame, got {last_block_frame}" + assert isinstance(last_block_frame, BlockFrame), f"Last frame must be a block frame, got {last_block_frame}" maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False) @@ -303,56 +303,48 @@ def Kernel( def get_thread_binding(dim: int = 0) -> Var: - """Returns the thread binding for the given dimension. - """ + """Returns the thread binding for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_binding(dim) def get_thread_bindings() -> list[Var]: - """Returns all three thread bindings. - """ + """Returns all three thread bindings.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_bindings() def get_block_binding(dim: int = 0) -> Var: - """Returns the block binding for the given dimension. - """ + """Returns the block binding for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_binding(dim) def get_block_bindings() -> list[Var]: - """Returns all three block bindings. - """ + """Returns all three block bindings.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_bindings() def get_thread_extent(dim: int = 0) -> int: - """Returns the thread extent for the given dimension. - """ + """Returns the thread extent for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extent(dim) def get_thread_extents() -> list[int]: - """Returns all three thread extents. - """ + """Returns all three thread extents.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extents() def get_block_extent(dim: int = 0) -> int: - """Returns the block extent for the given dimension. - """ + """Returns the block extent for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extent(dim) def get_block_extents() -> list[int]: - """Returns all three block extents. - """ + """Returns all three block extents.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extents() diff --git a/tilelang/language/logical.py b/tilelang/language/logical.py index a09088e..fb4b88a 100644 --- a/tilelang/language/logical.py +++ b/tilelang/language/logical.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang import language as T @@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion): ) new_region.append(r.min) buffer_load = BufferLoad(buffer, new_region) - return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), - extent) + return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), extent) else: raise ValueError(f"Invalid buffer type: {type(buffer)}") @@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion): ) new_region.append(r.min) buffer_load = BufferLoad(buffer, new_region) - return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), - extent) + return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), extent) else: raise ValueError(f"Invalid buffer type: {type(buffer)}") diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index 3478b6c..f28f097 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Any from tvm import tir @@ -94,11 +95,9 @@ def Pipelined( return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) -def serial(start: tir.PrimExpr, - stop: tir.PrimExpr | None = None, - step: tir.PrimExpr | None = None, - *, - annotations: dict[str, Any] | None = None) -> frame.ForFrame: +def serial( + start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None +) -> frame.ForFrame: step_is_one = False step_is_one |= isinstance(step, int) and step == 1 step_is_one |= isinstance(step, IntImm) and step.value == 1 @@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr, return SerialForWithStep(start, stop, step, annotations=annotations) -def unroll(start: tir.PrimExpr, - stop: tir.PrimExpr | None = None, - step: tir.PrimExpr | None = None, - *, - explicit: bool = False, - unroll_factor: int | None = None, - annotations: dict[str, Any] | None = None) -> frame.ForFrame: +def unroll( + start: tir.PrimExpr, + stop: tir.PrimExpr | None = None, + step: tir.PrimExpr | None = None, + *, + explicit: bool = False, + unroll_factor: int | None = None, + annotations: dict[str, Any] | None = None, +) -> frame.ForFrame: """The unrolled For statement. Parameters diff --git a/tilelang/language/math_intrinsics.py b/tilelang/language/math_intrinsics.py index 39cab27..7a6104c 100644 --- a/tilelang/language/math_intrinsics.py +++ b/tilelang/language/math_intrinsics.py @@ -3,7 +3,7 @@ from tvm import tir def _validate_rounding_mode(rounding_mode): """Validate that the rounding mode is one of the supported IEEE modes""" - valid_modes = {'rn', 'rz', 'ru', 'rd'} + valid_modes = {"rn", "rz", "ru", "rd"} if isinstance(rounding_mode, str) and rounding_mode in valid_modes: return raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}") diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py index af42098..0b2fcc4 100644 --- a/tilelang/language/overrides/parser.py +++ b/tilelang/language/overrides/parser.py @@ -1,4 +1,5 @@ """TVMScript parser overrides tailored for TileLang.""" + from functools import partial from tvm.script.ir_builder import tir as T @@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) continue @@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) return @@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) return diff --git a/tilelang/language/parser/entry.py b/tilelang/language/parser/entry.py index aa98cf5..5f2aaab 100644 --- a/tilelang/language/parser/entry.py +++ b/tilelang/language/parser/entry.py @@ -18,6 +18,7 @@ # which is part of the TVM project (https://tvm.apache.org/). # ruff: noqa """The entry point of TVM parser for tir.""" + import inspect from typing import Callable, Optional, Union @@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils from tvm.script.parser.core.parser import Parser, ScriptMacro -def prim_func(func: Optional[Callable] = None, - private: bool = False, - check_well_formed=True) -> Union[PrimFunc, Callable]: +def prim_func(func: Optional[Callable] = None, private: bool = False, check_well_formed=True) -> Union[PrimFunc, Callable]: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters @@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable: if len(args) == 1 and inspect.isfunction(args[0]): return _decorator(args[0]) - raise ValueError( - "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") + raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") class BufferProxy: diff --git a/tilelang/language/parser/operation.py b/tilelang/language/parser/operation.py index b2138ac..473da43 100644 --- a/tilelang/language/parser/operation.py +++ b/tilelang/language/parser/operation.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """The tir expression operation registration""" + from tvm import tir from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.tir import IntImm @@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name return dtype[0:index] def _auto_broadcast(a, b, op): - if isinstance(a, int): if hasattr(b, "dtype"): - if (DataType(b.dtype).type_code == DataTypeCode.INT or - DataType(b.dtype).type_code == DataTypeCode.UINT): + if DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT: a = IntImm(_get_type_str(b.dtype), a) elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: a = FloatImm(_get_type_str(b.dtype), a) @@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." if isinstance(b, int): - if (DataType(a.dtype).type_code == DataTypeCode.INT or - DataType(a.dtype).type_code == DataTypeCode.UINT): + if DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT: b = IntImm(_get_type_str(a.dtype), b) elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: b = FloatImm(_get_type_str(a.dtype), b) @@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name if DataType(a.dtype).lanes == DataType(b.dtype).lanes: return op(a, b) - elif (DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): + elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) return op(broadcast_a, b) - elif (DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): + elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) return op(a, broadcast_b) else: diff --git a/tilelang/language/parser/parser.py b/tilelang/language/parser/parser.py index 3aa720d..4cac0ad 100644 --- a/tilelang/language/parser/parser.py +++ b/tilelang/language/parser/parser.py @@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - res = value.__enter__() IRBuilder.name(var_name, res) return res - elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and - not self.var_table.exist(value)): + elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and not self.var_table.exist(value)): IRBuilder.name(var_name, value) return value else: @@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None: if not isinstance(for_frame, T.frame.ForFrame): self.report_error( node.iter, - "Expect the for loop to be one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + "Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", ) with self.var_table.with_frame(): with for_frame as iters: @@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None: for item in node.items: frame = self.eval_expr(item.context_expr) if not isinstance(frame, Frame): - self.report_error(item.context_expr, - "Invalid context expression in the with-statement.") + self.report_error(item.context_expr, "Invalid context expression in the with-statement.") rhs = stack.enter_context(frame) if item.optional_vars is not None: self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) @@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None: with self.var_table.with_frame(): self.visit_body(node.orelse) else: - self.report_error(node.test, - f"If condition must be a boolean expression, but got {predicate}") + self.report_error(node.test, f"If condition must be a boolean expression, but got {predicate}") @dispatch.register(token="tir", type_name="Assert") diff --git a/tilelang/language/print.py b/tilelang/language/print.py index 08e18f4..bbaa119 100644 --- a/tilelang/language/print.py +++ b/tilelang/language/print.py @@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: @macro -def print_var_with_condition(condition: tir.PrimExpr, - var: tir.PrimExpr, - msg: str = "") -> tir.PrimExpr: +def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: """ Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True. @@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr, @macro -def print_global_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. """ @@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) else: tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) @macro -def print_shared_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) @macro -def print_fragment_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr, @macro -def print_local_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) from tilelang.utils.target import check_cuda_availability @@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> elems *= dim # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. - condition = (tx == main_lane and ty == 0 and tz == 0) + condition = tx == main_lane and ty == 0 and tz == 0 if not msg: msg = f"buffer<{buffer.name}, {buffer.dtype}>" return print_fragment_buffer_with_condition(condition, buffer, elems, msg) @@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> elems *= dim # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. - condition = (tx == main_lane and ty == 0 and tz == 0) + condition = tx == main_lane and ty == 0 and tz == 0 if not msg: msg = f"buffer<{buffer.name}, {buffer.dtype}>" return print_shared_buffer_with_condition(condition, buffer, elems, msg) @@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> else: # Unsupported object type. - raise ValueError( - f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.") + raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.") diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 9e209a1..7807a46 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar @@ -51,11 +52,9 @@ class BufferProxy: return self(keys) return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member - def from_ptr(self, - pointer_var: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> Buffer: + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Buffer: """Create a buffer from a pointer, shape, and data type. Args: @@ -76,6 +75,7 @@ class BaseTensorProxy: customizable default values for scope, alignment, and offset factors. It implements the core functionality for creating TIR buffers with specific memory configurations. """ + default_scope = "global" default_align = 0 default_offset_factor = 0 @@ -118,11 +118,9 @@ class BaseTensorProxy: keys = (keys,) return self(*keys) - def from_ptr(self, - pointer_var: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> tir.Buffer: """Create a buffer from a pointer, shape, and data type. Args: @@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy): strides.append(s) return tuple(reversed(strides)) - def __call__(self, - shape: tuple[Any] | PrimExpr | int, - dtype: str = "float32", - data=None, - scope=None) -> tir.Buffer: + def __call__(self, shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer: if isinstance(shape, (int, PrimExpr)): shape = (shape,) - return super().__call__( - shape, - dtype=dtype, - strides=TensorProxy._construct_strides(shape), - data=data, - scope=scope) + return super().__call__(shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data, scope=scope) class StridedTensorProxy(BaseTensorProxy): @@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy): This class implements the default tensor proxy with global memory scope, with the stride information required. """ - def __call__(self, - shape: tuple[Any], - strides: tuple[Any], - dtype: str = "float32", - scope=None) -> tir.Buffer: + def __call__(self, shape: tuple[Any], strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer: if len(shape) != len(strides): raise ValueError("Invalid shape/strides' dimensions") return super().__call__(shape, dtype=dtype, strides=strides, scope=scope) @@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy): This class represents tensor proxies specifically for local fragment memory, typically used in GPU tensor core operations. """ + default_scope = "local.fragment" @@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy): This class represents tensor proxies for dynamic shared memory, commonly used in GPU shared memory operations. """ + default_scope = "shared.dyn" @@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy): This class represents tensor proxies for local memory scope, typically used for temporary computations in GPU kernels. """ + default_scope = "local" @@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name if TYPE_CHECKING: class BaseTensor: - def __class_getitem__(cls, key): return cls - def __getitem__(self, key) -> Any: - ... + def __getitem__(self, key) -> Any: ... - def __setitem__(self, key, value) -> None: - ... + def __setitem__(self, key, value) -> None: ... def __init__( self, @@ -238,36 +223,26 @@ if TYPE_CHECKING: offset_factor=None, buffer_type="", axis_separators=None, - ): - ... + ): ... @classmethod - def from_ptr(cls, - pointer_var: Var, - shape: Sequence[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> Self: - ... + def from_ptr( + cls, pointer_var: Var, shape: Sequence[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Self: ... - class Tensor(BaseTensor): - ... + class Tensor(BaseTensor): ... - class StridedTensor(BaseTensor): - ... + class StridedTensor(BaseTensor): ... - class FragmentBuffer(BaseTensor): - ... + class FragmentBuffer(BaseTensor): ... - class SharedBuffer(BaseTensor): - ... + class SharedBuffer(BaseTensor): ... - class LocalBuffer(BaseTensor): - ... + class LocalBuffer(BaseTensor): ... - _T = TypeVar('_T') + _T = TypeVar("_T") - class Ref(Generic[_T], tir.Var): - ... + class Ref(Generic[_T], tir.Var): ... else: Tensor = TensorProxy() # pylint: disable=invalid-name StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name @@ -275,14 +250,10 @@ else: SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name - class Ref: - ... + class Ref: ... -def ptr(dtype: str | None = None, - storage_scope: str = "global", - *, - is_size_var: bool = False) -> Var: +def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -304,8 +275,5 @@ def ptr(dtype: str | None = None, return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var) -def make_tensor(ptr: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: +def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: return Tensor.from_ptr(ptr, shape, dtype, strides) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index fb84b6d..9bb3b17 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tvm import tir from tilelang.language import copy, macro, alloc_shared, alloc_fragment @@ -30,15 +31,13 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.Call: Handle to the reduction operation """ # input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y] - expected_shapes = [ - buffer.shape[:dim] + buffer.shape[dim + 1:], - buffer.shape[:dim] + [1] + buffer.shape[dim + 1:] - ] + expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]] if list(out.shape) not in expected_shapes: - expected_shapes_str = ' or '.join(map(str, expected_shapes)) + expected_shapes_str = " or ".join(map(str, expected_shapes)) raise ValueError( f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " - f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}" + ) @macro def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): diff --git a/tilelang/language/tir/entry.py b/tilelang/language/tir/entry.py index 22702ae..82ae7d7 100644 --- a/tilelang/language/tir/entry.py +++ b/tilelang/language/tir/entry.py @@ -7,9 +7,7 @@ from tvm.tir.function import PrimFunc from tvm.script.parser._core import parse, scan_macro, utils -def prim_func(func: Callable | None = None, - private: bool = False, - check_well_formed: bool = False) -> PrimFunc | Callable: +def prim_func(func: Callable | None = None, private: bool = False, check_well_formed: bool = False) -> PrimFunc | Callable: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters @@ -113,8 +111,7 @@ def macro(*args, hygienic: bool = True) -> Callable: if len(args) == 1 and inspect.isfunction(args[0]): return _decorator(args[0]) - raise ValueError( - "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") + raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") setattr(macro, "dispatch_token", "tir") # noqa: B010 diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 74cb32f..a836793 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -6,10 +6,7 @@ import tilelang.language.tir.op as _tir_op import functools -def serial(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The serial For statement. Parameters @@ -31,10 +28,7 @@ def serial(start: PrimExpr, return _ir.serial(start=start, stop=stop, annotations=annotations) -def parallel(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The parallel For statement. Parameters @@ -56,10 +50,7 @@ def parallel(start: PrimExpr, return _ir.parallel(start=start, stop=stop, annotations=annotations) -def vectorized(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The vectorized For statement. Parameters @@ -81,10 +72,7 @@ def vectorized(start: PrimExpr, return _ir.vectorized(start=start, stop=stop, annotations=annotations) -def unroll(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The unrolled For statement. Parameters @@ -161,7 +149,6 @@ def grid(*extents: PrimExpr) -> frame.ForFrame: def _dtype_forward(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: @@ -172,7 +159,6 @@ def _dtype_forward(func): def _op_wrapper(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi index fe25b58..7723f13 100644 --- a/tilelang/language/tir/ir.pyi +++ b/tilelang/language/tir/ir.pyi @@ -1,22 +1,22 @@ from typing import TypeVar, Literal from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm -_T = TypeVar('_T') +_T = TypeVar("_T") -def abs(x: _T, span: Span | None=None) -> _T: ... +def abs(x: _T, span: Span | None = None) -> _T: ... def acos(x: _T) -> _T: ... def acosh(x: _T) -> _T: ... -def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ... +def address_of(buffer_load: BufferLoad, span: Span | None = None) -> PrimExpr: ... def asin(x: _T) -> _T: ... def asinh(x: _T) -> _T: ... def atan(x: _T) -> _T: ... def atan2(x1: _T, x2: _T) -> _T: ... def atanh(x: _T) -> _T: ... -def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ... -def bitwise_not(x: _T, span: Span | None=None) -> _T: ... -def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ... -def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ... -def ceil(x: _T, span: Span | None=None) -> _T: ... +def bitwise_and(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_not(x: _T, span: Span | None = None) -> _T: ... +def bitwise_or(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_xor(x: _T, y: _T, span: Span | None = None) -> _T: ... +def ceil(x: _T, span: Span | None = None) -> _T: ... def clz(x: _T) -> _T: ... def copysign(x1: _T, x2: _T) -> _T: ... def cos(x: _T) -> _T: ... @@ -25,35 +25,37 @@ def erf(x: _T) -> _T: ... def exp(x: _T) -> _T: ... def exp2(x: _T) -> _T: ... def exp10(x: _T) -> _T: ... -def floor(x: _T, span: Span | None=None) -> _T: ... -def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ... -def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ... -def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def floor(x: _T, span: Span | None = None) -> _T: ... +def ceildiv(lhs: _T, rhs: _T, span: Span | None = None) -> _T: ... +def floordiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def floormod(a: _T, b: _T, span: Span | None = None) -> _T: ... def fmod(x: _T, y: _T) -> _T: ... def hypot(x1: _T, x2: _T) -> _T: ... -def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... -def infinity(dtype: _T, span: Span | None=None) -> _T: ... -def isfinite(x: _T, span: Span | None=None) -> _T: ... -def isinf(x: _T, span: Span | None=None) -> _T: ... -def isnan(x: _T, span: Span | None=None) -> _T: ... -def isnullptr(x: _T, span: Span | None=None) -> _T: ... +def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None = None) -> _T: ... +def infinity(dtype: _T, span: Span | None = None) -> _T: ... +def isfinite(x: _T, span: Span | None = None) -> _T: ... +def isinf(x: _T, span: Span | None = None) -> _T: ... +def isnan(x: _T, span: Span | None = None) -> _T: ... +def isnullptr(x: _T, span: Span | None = None) -> _T: ... def ldexp(x1: _T, x2: _T) -> _T: ... -def likely(cond: _T, span: Span | None=None) -> _T: ... +def likely(cond: _T, span: Span | None = None) -> _T: ... def log(x: _T) -> _T: ... def log1p(x: _T) -> _T: ... def log2(x: _T) -> _T: ... def log10(x: _T) -> _T: ... -def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ... -def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ... -def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ... -def nearbyint(x: _T, span: Span | None=None) -> _T: ... +def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ... +def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def nearbyint(x: _T, span: Span | None = None) -> _T: ... def nextafter(x1: _T, x2: _T) -> _T: ... def popcount(x: _T) -> _T: ... -def pow(x: _T, y: _T, span: Span | None=None) -> _T: ... +def pow(x: _T, y: _T, span: Span | None = None) -> _T: ... def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ... -def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ... +def q_multiply_shift_per_axis( + x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm +) -> PrimExpr: ... def ret(val: _T) -> _T: ... -def round(x: _T, span: Span | None=None) -> _T: ... +def round(x: _T, span: Span | None = None) -> _T: ... def rsqrt(x: _T) -> _T: ... def shift_left(x: _T, y: _T, span=None) -> _T: ... def shift_right(x: _T, y: _T, span=None) -> _T: ... @@ -63,14 +65,16 @@ def sinh(x: _T) -> _T: ... def sqrt(x: _T) -> _T: ... def tan(x: _T) -> _T: ... def tanh(x: _T) -> _T: ... -def trunc(x: _T, span: Span | None=None) -> _T: ... -def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ... -def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def trunc(x: _T, span: Span | None = None) -> _T: ... +def truncdiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def truncmod(a: _T, b: _T, span: Span | None = None) -> _T: ... def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... def tvm_throw_last_error() -> _T: ... def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... def tvm_stack_make_shape(*args) -> _T: ... -def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ... +def tvm_stack_make_array( + data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset +) -> PrimExpr: ... def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ... def call_packed(*args, span=None) -> _T: ... def call_cpacked(*args, span=None) -> _T: ... @@ -80,11 +84,47 @@ def tvm_tuple(*value) -> _T: ... def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... def tvm_thread_invariant(cond: _T) -> _T: ... def tvm_thread_allreduce(*freduce_args) -> _T: ... -def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... -def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... -def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... +def tvm_load_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... +def tvm_mma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... +def tvm_bmma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ... -def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... +def tvm_store_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... def ptx_wait_group(num: int) -> PrimExpr: ... def ptx_commit_group() -> _T: ... def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... @@ -93,7 +133,7 @@ def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... def create_barriers(barrier_count: int) -> PrimExpr: ... -def assume(cond: _T=None) -> _T: ... +def assume(cond: _T = None) -> _T: ... def undef() -> _T: ... def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index a9ce6a5..6cf7841 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -724,8 +724,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout) -def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, - index_c): +def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): """TVM intrinsic for tensor core mma_sync operators Parameters @@ -759,12 +758,10 @@ def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, call : PrimExpr The call expression. """ - return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, - fragment_c, index_c) + return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) -def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, - index_c): +def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): """TVM intrinsic for tensor core bmma_sync operators Parameters @@ -798,8 +795,7 @@ def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, call : PrimExpr The call expression. """ - return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, - fragment_c, index_c) + return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) def tvm_fill_fragment(fragment, m, n, k, index, value): @@ -1121,7 +1117,6 @@ def ptx_wgmma_rs( scale_in_a, scale_in_b, ): - return call_intrin( dtype, _tvm_op.Op.get("tl.ptx_wgmma_rs"), @@ -1345,8 +1340,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme call : PrimExpr The call expression. """ - return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, - smem_offset) + return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes): @@ -1381,8 +1375,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes) -def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, - barrier_id): +def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id): """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk @@ -1414,8 +1407,7 @@ def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offse call : PrimExpr The call expression. """ - return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, - bytes, barrier_id) + return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id) def ptx_commit_group(): @@ -2951,8 +2943,7 @@ def q_multiply_shift_per_axis( z : PrimExpr The result. """ - return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, - is_rshift_required) + return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required) def shift_left(x, y, span=None): @@ -3302,8 +3293,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt call : PrimExpr The call expression. """ - return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, - dtype_bits_hint) + return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint) def TVMBackendFreeWorkspace(device_type, device_id, ptr): diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 136bc0b..7d68294 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -14,23 +14,18 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list """Convert a BufferLoad to a tl.region call with explicit extents.""" indices = list(load.indices) if len(indices) > len(extents): - extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents)) - ] + list(extents) + extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))] + list(extents) assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" return region(load, access_type, *extents) -def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, - extents: list[tir.PrimExpr]): +def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]): """Clamp extents and return a tl.region call.""" mins = [r.min for r in buffer_region.region] region_extents = [r.extent for r in buffer_region.region] - assert len(region_extents) >= len(extents), ( - f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" - ) + assert len(region_extents) >= len(extents), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" clamped_extents = [ - tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] - for i in range(len(region_extents)) + tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents)) ] return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents) diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py index b61d9d1..bac9214 100644 --- a/tilelang/language/v2/annot.py +++ b/tilelang/language/v2/annot.py @@ -5,6 +5,7 @@ from tvm import tir from tvm.ir.expr import PrimExpr from tvm.script.ir_builder.tir import buffer from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING + # Python 3.9 compatibility for advanced typing features try: from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined] @@ -37,16 +38,16 @@ from tvm.script.ir_builder import IRBuilder import torch import inspect -_Shapes = TypeVarTuple('_Shapes') -_Shape = ParamSpec('_Shape') -_Stride = ParamSpec('_Stride') -_DType = TypeVar('_DType') +_Shapes = TypeVarTuple("_Shapes") +_Shape = ParamSpec("_Shape") +_Stride = ParamSpec("_Stride") +_DType = TypeVar("_DType") -Scope = Literal['global', 'shared.dyn', 'local', 'local.fragment'] +Scope = Literal["global", "shared.dyn", "local", "local.fragment"] class Annot(ABC): - ''' + """ Base class for tilelang kernel annotations Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel @@ -54,12 +55,12 @@ class Annot(ABC): 1. determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) 2. parse the argument value into a hash key for jit caching 3. convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ def is_kernel_arg(self) -> bool: - ''' + """ Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) - ''' + """ return False @abstractmethod @@ -68,29 +69,29 @@ class Annot(ABC): @abstractmethod def get_key_parser(self) -> Callable[[str, Any], tuple[Any, ...]]: - ''' + """ Return a parser function that converts the argument value into a hash key for jit caching - ''' + """ @abstractmethod def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable) -> tir.Var | tir.Buffer: - ''' + """ Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ def promote(self) -> TIRAnnot | None: - ''' + """ Try to promote the annotation into a FixedAnnot if possible Return None if not promotable - ''' + """ return None @dataclass class ArgVarTable: - ''' + """ ArgVarTable is used to manage the mapping from argument names to tir.Var objects - ''' + """ var_tab: dict[str, tir.Var] = field(default_factory=dict) tmp_name_idx: int = 0 @@ -103,50 +104,49 @@ class ArgVarTable: return self.var_tab[name] def create_tmp_name(self) -> str: - name = f'varg_{self.tmp_name_idx}' + name = f"varg_{self.tmp_name_idx}" self.tmp_name_idx += 1 return name @dataclass class Value(Annot): - kind: Literal['static', 'dynamic'] = 'dynamic' + kind: Literal["static", "dynamic"] = "dynamic" name: str | None = None dtype: dt.dtype | None = dt.int32 value: int | tir.Var | None = None creator: Callable[[], Any] | None = None def is_kernel_arg(self) -> bool: - return self.kind == 'dynamic' + return self.kind == "dynamic" @classmethod def from_value(cls, value: Any, prefer_name: str = None) -> Value: if isinstance(value, int): # handle A: T.Tensor[[1024, 1024], ...] - return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value) + return Value(kind="static", name=prefer_name, dtype=dt.int32, value=value) elif isinstance(value, float): - return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value) + return Value(kind="static", name=prefer_name, dtype=dt.float32, value=value) elif isinstance(value, dt.dtype): # handle A: T.float32 - return Value(kind='dynamic', name=prefer_name, dtype=value, value=None) + return Value(kind="dynamic", name=prefer_name, dtype=value, value=None) elif isinstance(value, Value): # handle A: T.dyn return value elif isinstance(value, TypeVar): - return Value(kind='static', name=value.__name__, value=None) + return Value(kind="static", name=value.__name__, value=None) elif isinstance(value, (tir.Var, PrimExpr)): # handle A: T.Tensor[[M, N, K], ...] # or primexpr annotation like A: T.Tensor[[M, N * 4 +1]] name = value.name if isinstance(value, tir.Var) else prefer_name - return Value(kind='dynamic', name=name, dtype=value.dtype, value=value) - elif value is Any or value is None or value is dt.dtype or isinstance( - value, (type,) + _GenericAliasTypes): + return Value(kind="dynamic", name=name, dtype=value.dtype, value=value) + elif value is Any or value is None or value is dt.dtype or isinstance(value, (type,) + _GenericAliasTypes): # A # no annotation # A: Any # A: _T # A: dt.dtype # A: tuple[...] - return Value(kind='static', name=prefer_name, value=None) + return Value(kind="static", name=prefer_name, value=None) else: raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}") @@ -154,7 +154,7 @@ class Value(Annot): return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value) def get_key_parser(self): - if self.kind == 'static': + if self.kind == "static": if self.value is not None: expected_value = self.value @@ -172,7 +172,7 @@ class Value(Annot): return self.get_key_parser()(target) def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable, create_arg: bool = True): - if self.kind == 'static': + if self.kind == "static": if self.value: assert self.value == value, f"static value mismatch for {name}: expected {self.value}, got {value}" return value @@ -187,18 +187,18 @@ class Value(Annot): return tb_tir.arg(name, arg) if create_arg else arg def __repr__(self): - if self.kind == 'static': + if self.kind == "static": if self.value is not None: return repr(self.value) else: - return (str(self.name) or '$unnamed') + '$' + return (str(self.name) or "$unnamed") + "$" else: if self.value is not None: return repr(self.value) elif self.creator is not None: return repr(self.creator()) else: - return (str(self.name) or '$unnamed') + '$dyn' + return (str(self.name) or "$unnamed") + "$dyn" def _canonicalize_dtype(val: Any) -> dt.dtype | None: @@ -226,7 +226,7 @@ def _shape_with_name(shape: Sequence[Value], base_name: str) -> list[Value]: return None res = [] for i, dim in enumerate(shape): - dim = dim.with_name(f'{base_name}_{i}') + dim = dim.with_name(f"{base_name}_{i}") res.append(dim) return res @@ -236,7 +236,7 @@ def _try_convert_static_shape(shape: Sequence[Value]): return None res = [] for s in shape: - if s.kind == 'static' and s.value is not None or s.kind == 'dynamic' and s.value is not None: + if s.kind == "static" and s.value is not None or s.kind == "dynamic" and s.value is not None: res.append(s.value) if len(res) == len(shape): return res @@ -253,7 +253,7 @@ class BufferAnnot(Annot): @property def scope(self): - return 'global' + return "global" def __call__( self, @@ -290,8 +290,8 @@ class BufferAnnot(Annot): return self.__class__(shape, strides=self.strides, dtype=dtype) def with_name(self, name: str): - shape = _shape_with_name(self.shape, base_name=f'{name}_shape') - strides = _shape_with_name(self.strides, base_name=f'{name}_stride') + shape = _shape_with_name(self.shape, base_name=f"{name}_shape") + strides = _shape_with_name(self.strides, base_name=f"{name}_stride") return self.__class__(shape, strides, self.dtype) def get_key_parser(self): @@ -299,14 +299,14 @@ class BufferAnnot(Annot): if self.shape is not None: raw_shapes = False shape_len = len(self.shape) - static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static'] + static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == "static"] # static_fixed_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static' and dim.value is not None] # static_fixed_shape_values = [dim.value for dim in self.shape if dim.kind == 'static' and dim.value is not None] raw_strides = True if self.strides is not None: raw_strides = False strides_len = len(self.strides) - strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static'] + strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == "static"] # static_fixed_strides_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static' and dim.value is not None] # static_fixed_strides_values = [dim.value for dim in self.strides if dim.kind == 'static' and dim.value is not None] raw_dtype = True @@ -340,9 +340,7 @@ class BufferAnnot(Annot): if not raw_dtype: dtype = dt.dtype(dtype) if dtype != expected_dtype: - raise TypeError( - f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}" - ) + raise TypeError(f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}") return shape, strides, dtype return key_parser @@ -384,7 +382,6 @@ class BufferAnnot(Annot): class TensorAnnot(BufferAnnot): - @staticmethod def _construct_strides(shape: tuple[Any]): s, strides = 1, [1] @@ -419,7 +416,8 @@ class TensorAnnot(BufferAnnot): align=align, offset_factor=offset_factor, buffer_type=buffer_type, - axis_separators=axis_separators) + axis_separators=axis_separators, + ) def promote(self): shape = _try_convert_static_shape(self.shape) @@ -430,7 +428,6 @@ class TensorAnnot(BufferAnnot): class StridedTensorAnnot(BufferAnnot): - def __call__( self, shape, @@ -466,30 +463,27 @@ class StridedTensorAnnot(BufferAnnot): class FragmentBufferAnnot(BufferAnnot): - @property def scope(self): - return 'local.fragment' + return "local.fragment" class SharedBufferAnnot(BufferAnnot): - @property def scope(self): - return 'shared.dyn' + return "shared.dyn" class LocalBufferAnnot(BufferAnnot): - @property def scope(self): - return 'local' + return "local" class DynAnnot(Value): - ''' + """ Dynamic variable annotation represents a tvm tir.Var argument - ''' + """ def __call__(self, dtype: AnyDType = dt.float32, name: str | None = None) -> DynAnnot: return tir.Var(name, dtype) @@ -499,16 +493,16 @@ class DynAnnot(Value): params = (params,) dtype = None if len(params) == 1: - name, = params + (name,) = params if len(params) == 2: dtype, name = params dtype = _canonicalize_dtype(dtype) or dt.int32 - return DynAnnot(kind='dynamic', dtype=dtype, name=name) + return DynAnnot(kind="dynamic", dtype=dtype, name=name) @dataclass class DTypeAnnot(Annot): - ''' + """ Data type annotation ensures automatically conversion from AnyDType to dtype >>> def foo(A: T.dtype): print(A) >>> foo(torch.float32) @@ -517,7 +511,8 @@ class DTypeAnnot(Annot): dtype('float32') >>> foo('float32') dtype('float32') - ''' + """ + name: str | None = None def is_kernel_arg(self) -> bool: @@ -533,15 +528,16 @@ class DTypeAnnot(Annot): return dt.dtype(value) def __repr__(self): - return self.name + '$dtype' + return self.name + "$dtype" @dataclass class TIRAnnot(Annot): - ''' + """ TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments >>> def foo(A: T.Buffer((128,), T.float32)): ... - ''' + """ + data: tir.Buffer | tir.Var def is_kernel_arg(self) -> bool: @@ -564,7 +560,6 @@ class TIRAnnot(Annot): if TYPE_CHECKING: class Buffer(Generic[_Shape, _DType]): - def __init__( shape: tuple[Unpack[_Shapes]], dtype: _DType = "float32", @@ -576,26 +571,20 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: ... @property - def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: - ... + def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: ... @property - def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: - ... + def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: ... @property - def strides(self) -> tuple[tir.PrimExpr]: - ... + def strides(self) -> tuple[tir.PrimExpr]: ... - def scope(self) -> Scope: - ... + def scope(self) -> Scope: ... class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]): - def __new__( shape: tuple[Unpack[_Shapes]], dtype: _DType = "float32", @@ -607,11 +596,9 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]): - def __new__( shape: tuple[Unpack[_Shapes]], strides=None, @@ -623,8 +610,7 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... class FragmentBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]): pass @@ -636,16 +622,12 @@ if TYPE_CHECKING: pass class dyn(tir.Var): - - def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: - ... + def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: ... @property - def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: - ... + def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: ... else: - Buffer = BufferAnnot() Tensor = TensorAnnot() StridedTensor = StridedTensorAnnot() @@ -670,7 +652,7 @@ class FuncAnnot: ker_arg_names = [] for param in sig.parameters.values(): name = param.name - annot = func_annots.get(name, Value('static', name)) + annot = func_annots.get(name, Value("static", name)) if not isinstance(annot, Annot): if not isinstance(annot, type) and callable(annot): annot = annot() @@ -679,7 +661,7 @@ class FuncAnnot: elif isinstance(annot, (tir.Buffer, tir.Var)): annot = TIRAnnot(data=annot) else: - annot = Value(kind='static', name=name) + annot = Value(kind="static", name=name) annot = annot.promote() or annot annots[name] = annot.with_name(name) if annot.is_kernel_arg(): @@ -689,9 +671,9 @@ class FuncAnnot: return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names) def parse_key(self, *args, **kws): - ''' + """ Parse arguments and generates the cache key for jit caching - ''' + """ args = {name: arg for name, arg in zip(self.arg_names, args)} arg_dict = dict(**args, **kws) parsed = [] @@ -706,15 +688,15 @@ class FuncAnnot: return [arg_dict[name] for name in self.ker_arg_names] def create_argument(self, name: str, value: Any, vt: ArgVarTable): - ''' + """ Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ return self.annots[name].create_prim_func_arg(name, value, vt) def is_all_static(self): - ''' + """ Check if all arguments are static (i.e., can be fully determined at compile time) - ''' + """ return all(isinstance(annot, TIRAnnot) for annot in self.annots.values()) def get_all_static_args(self): diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index c6dfecf..26c1851 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -4,16 +4,18 @@ from dataclasses import dataclass from typing import Callable, Generic, Any, Literal, TypeVar from contextlib import AbstractContextManager from collections.abc import Iterable + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec except ImportError: # Python < 3.10 from typing_extensions import ParamSpec import inspect + # from .utils import get_ast, get_compiled_object from . import utils -_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset'] +_span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"] def ast_has_span(ast: ast.AST) -> bool: @@ -34,7 +36,6 @@ def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]): class QuoteVisitor(ast.NodeTransformer): - def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None): self.names = names self.passes = passes or [] @@ -76,9 +77,8 @@ def quote_expr(expr: str, **kws) -> ast.expr: return res.value -Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', - 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'] -BoolOp = Literal['And', 'Or', 'Not'] +Operator = Literal["Add", "Sub", "Mult", "MatMult", "Div", "Mod", "Pow", "LShift", "RShift", "BitOr", "BitXor", "BitAnd", "FloorDiv"] +BoolOp = Literal["And", "Or", "Not"] def get_operator_name(operator: ast.operator) -> Operator: @@ -89,84 +89,83 @@ def get_boolop_name(boolop: ast.boolop) -> BoolOp: return boolop.__class__.__name__ -_T = TypeVar('_T') +_T = TypeVar("_T") def eval_op(op: Operator, left: Any, right: Any) -> Any: - if op == 'Add': + if op == "Add": return left + right - if op == 'Sub': + if op == "Sub": return left - right - if op == 'Mult': + if op == "Mult": return left * right - if op == 'MatMult': + if op == "MatMult": return left @ right - if op == 'Div': + if op == "Div": return left / right - if op == 'Mod': + if op == "Mod": return left % right - if op == 'Pow': + if op == "Pow": return left**right - if op == 'LShift': + if op == "LShift": return left << right - if op == 'RShift': + if op == "RShift": return left >> right - if op == 'BitOr': + if op == "BitOr": return left | right - if op == 'BitXor': + if op == "BitXor": return left ^ right - if op == 'BitAnd': + if op == "BitAnd": return left & right - if op == 'FloorDiv': + if op == "FloorDiv": return left // right - raise ValueError(f'Unknown operator: {op}') + raise ValueError(f"Unknown operator: {op}") def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any: - if op == 'Add': + if op == "Add": left[sl] += right return left - if op == 'Sub': + if op == "Sub": left[sl] -= right return left - if op == 'Mult': + if op == "Mult": left[sl] *= right return left - if op == 'MatMult': + if op == "MatMult": left[sl] @= right return left - if op == 'Div': + if op == "Div": left[sl] /= right return left - if op == 'Mod': + if op == "Mod": left[sl] %= right return left - if op == 'Pow': + if op == "Pow": left[sl] **= right return left - if op == 'LShift': + if op == "LShift": left[sl] <<= right return left - if op == 'RShift': + if op == "RShift": left[sl] >>= right return left - if op == 'BitOr': + if op == "BitOr": left[sl] |= right return left - if op == 'BitXor': + if op == "BitXor": left[sl] ^= right return left - if op == 'BitAnd': + if op == "BitAnd": left[sl] &= right return left - if op == 'FloorDiv': + if op == "FloorDiv": left[sl] //= right return left - raise ValueError(f'Unknown operator: {op}') + raise ValueError(f"Unknown operator: {op}") -class _empty: - ... +class _empty: ... class BaseBuilder: @@ -218,13 +217,13 @@ class BaseBuilder: eval_aug_assign(op, target, sl, aug_value) def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any: - if op == 'And': + if op == "And": return left and right() - if op == 'Or': + if op == "Or": return left or right() - if op == 'Not': + if op == "Not": return not left - raise ValueError(f'Unknown boolop: {op}') + raise ValueError(f"Unknown boolop: {op}") def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: return then() if cond else otherwise() @@ -249,7 +248,6 @@ class BaseBuilder: class DSLMutator(ast.NodeTransformer): - def __init__(self, closure_names: list[str]): self.tmp_counter = 0 self.closure_names = closure_names @@ -264,19 +262,13 @@ class DSLMutator(ast.NodeTransformer): br = self.get_tmp() if len(node.orelse) == 0: return quote( - f"for {br} in __tb.ctx_if(cond):\n" - f" for _ in __tb.ctx_then({br}):\n" - " pass\n", + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n", cond=node.test, passes=[node.body], span=node, ) return quote( - f"for {br} in __tb.ctx_if(cond):\n" - f" for _ in __tb.ctx_then({br}):\n" - f" pass\n" - f" for _ in __tb.ctx_else({br}):\n" - f" pass\n", + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n for _ in __tb.ctx_else({br}):\n pass\n", cond=node.test, passes=[node.body, node.orelse], span=node, @@ -290,7 +282,7 @@ class DSLMutator(ast.NodeTransformer): if isinstance(target, ast.Name): return f"'{target.id}'" elif isinstance(target, ast.Tuple): - return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)") + return "(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)" else: s = ast.unparse(target) raise NotImplementedError(f"Unsupported for target `{s}`") @@ -303,8 +295,7 @@ class DSLMutator(ast.NodeTransformer): ast_set_span(var, ast_get_span(node.target)) stmts = self._emit_assign_target(node.target, var) return quote( - f"for {tmp} in __tb.ctx_for(range):\n" - " pass\n", + f"for {tmp} in __tb.ctx_for(range):\n pass\n", target=node.target, range=node.iter, passes=[stmts + node.body], @@ -319,24 +310,15 @@ class DSLMutator(ast.NodeTransformer): node = self.generic_visit(node) return quote("if __tb.ctx_break(): break", span=node) - def _emit_assign_target(self, - target: ast.expr, - rval: ast.expr, - annot: ast.expr = None) -> list[ast.AST]: + def _emit_assign_target(self, target: ast.expr, rval: ast.expr, annot: ast.expr = None) -> list[ast.AST]: if isinstance(target, ast.Name): if annot is None: - return quote( - f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + return quote(f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) else: - return quote( - f'name = __tb.bind("{target.id}", value, annot)', - name=target, - value=rval, - annot=annot, - span=target) + return quote(f'name = __tb.bind("{target.id}", value, annot)', name=target, value=rval, annot=annot, span=target) elif isinstance(target, ast.Attribute): s = ast.unparse(target) - raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") elif isinstance(target, ast.Subscript): if annot is None: return quote( @@ -356,7 +338,6 @@ class DSLMutator(ast.NodeTransformer): span=target, ) else: - # flatten nested tuple into a list of (tmp_name, target) unpacked = [] @@ -374,11 +355,9 @@ class DSLMutator(ast.NodeTransformer): return res else: s = ast.unparse(target) - raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") - unpack_stmt = ast.Assign( - targets=[_visit_target(target)], - value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval)) + unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=quote_expr("__tb.unwrap_value(rval)", rval=rval, span=rval)) ast_set_span(unpack_stmt, ast_get_span(target)) stmts = [unpack_stmt] bind_lvals = [] @@ -386,8 +365,7 @@ class DSLMutator(ast.NodeTransformer): def flush_binds(): if bind_lvals: - stmts.append( - quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target)) + stmts.append(quote1(f"{', '.join(bind_lvals)}, = {', '.join(bind_rvals)},", span=target)) bind_lvals.clear() bind_rvals.clear() @@ -417,15 +395,10 @@ class DSLMutator(ast.NodeTransformer): bind_rvals.append(f'__tb.bind("{target.id}", {tmp})') elif isinstance(target, ast.Subscript): flush_binds() - stmts.append( - quote1( - f'__tb.assign_slice(lval, slice, {tmp})', - lval=target.value, - slice=target.slice, - span=target)) + stmts.append(quote1(f"__tb.assign_slice(lval, slice, {tmp})", lval=target.value, slice=target.slice, span=target)) else: s = ast.unparse(target) - raise NotImplementedError(f'Unsupported target: {s}') + raise NotImplementedError(f"Unsupported target: {s}") flush_binds() return stmts @@ -450,11 +423,7 @@ class DSLMutator(ast.NodeTransformer): target, rval = node.target, node.value op = get_operator_name(node.op) if isinstance(target, ast.Name): - return quote( - f"name = __tb.aug_assign('{op}', {target.id}, value)", - name=target, - value=rval, - span=node) + return quote(f"name = __tb.aug_assign('{op}', {target.id}, value)", name=target, value=rval, span=node) elif isinstance(target, ast.Subscript): return quote( f"__tb.aug_assign_slice('{op}', lval, slice, value)", @@ -468,16 +437,12 @@ class DSLMutator(ast.NodeTransformer): def visit_AnnAssign(self, node: ast.AnnAssign): node = self.generic_visit(node) - rval = node.value or quote_expr('__tb.empty', span=node, annot=node) + rval = node.value or quote_expr("__tb.empty", span=node, annot=node) return self._emit_assign_target(node.target, rval, annot=node.annotation) def visit_While(self, node): node = self.generic_visit(node) - return quote1( - "for _ in __tb.ctx_while(lambda: cond):\n pass", - cond=node.test, - passes=[node.body], - span=node) + return quote1("for _ in __tb.ctx_while(lambda: cond):\n pass", cond=node.test, passes=[node.body], span=node) def visit_FunctionDef(self, node: ast.FunctionDef): node = self.generic_visit(node) @@ -536,18 +501,14 @@ class DSLMutator(ast.NodeTransformer): left = comp last = split[-1] for i in reversed(range(len(split) - 1)): - last = quote_expr( - "__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) + last = quote_expr("__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) return last def visit_IfExp(self, node: ast.IfExp) -> ast.Expr: node = self.generic_visit(node) return quote_expr( - '__tb.ifexp(cond, lambda: then, lambda: otherwise)', - cond=node.test, - then=node.body, - otherwise=node.orelse, - span=node) + "__tb.ifexp(cond, lambda: then, lambda: otherwise)", cond=node.test, then=node.body, otherwise=node.orelse, span=node + ) def visit_Return(self, node: ast.Return): node = self.generic_visit(node) @@ -569,7 +530,7 @@ class DSLMutator(ast.NodeTransformer): return node -_P = ParamSpec('_P') +_P = ParamSpec("_P") @dataclass @@ -626,7 +587,7 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: make_closure = utils.get_compiled_object( tree, - 'make_closure', + "make_closure", filename, func.__globals__, # use the original globalns ) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 436756d..645a1ad 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, U from collections.abc import Sequence from .annot import FuncAnnot, ArgVarTable, Annot import pprint + # Python 3.9 compatibility for ParamSpec and Self try: from typing import ParamSpec, Self @@ -32,9 +33,9 @@ logger = logging.getLogger(__name__) def unwrap_expr(expr) -> PrimExpr | int | float: - ''' + """ unwrap expr and convert it into PrimExpr like - ''' + """ if isinstance(expr, tir.meta_var): expr = expr.value elif isinstance(expr, Ref): @@ -47,9 +48,9 @@ def unwrap_expr(expr) -> PrimExpr | int | float: def unwrap_cond(expr): - ''' + """ unwrap expr and convert to bool condition - ''' + """ expr = unwrap_expr(expr) if isinstance(expr, (IntImm, FloatImm, StringImm)): return bool(expr.value) @@ -61,10 +62,10 @@ def unwrap_cond(expr): return bool(expr) else: logger.warning( - f"Python expression `{expr}` is used as condition in TileLang, \n" - "this is treated as a constant expression. ", + f"Python expression `{expr}` is used as condition in TileLang, \nthis is treated as a constant expression. ", stack_info=True, - stacklevel=3) + stacklevel=3, + ) return bool(expr) @@ -72,44 +73,35 @@ thread_local_storage = threading.local() class Frame: - ''' + """ Frame are virtual context managers used in frontend only They do not have any runtime representation in the generated TIR. - ''' + """ - def __enter__(self): - ... + def __enter__(self): ... - def __exit__(self, exc_type, exc_value, traceback): - ... + def __exit__(self, exc_type, exc_value, traceback): ... -class MacroFrame(Frame): - ... +class MacroFrame(Frame): ... -class ExitedMacroFrame(Frame): - ... +class ExitedMacroFrame(Frame): ... -class BoolOpFrame(Frame): - ... +class BoolOpFrame(Frame): ... -class ConstIfFrame(Frame): - ... +class ConstIfFrame(Frame): ... -class BlockFrame(Frame): - ... +class BlockFrame(Frame): ... -class ContinueFrame(Frame): - ... +class ContinueFrame(Frame): ... -class BreakFrame(Frame): - ... +class BreakFrame(Frame): ... @dataclass @@ -145,8 +137,7 @@ class Ref: return self.bufload -class UnrollForWithStep(SerialForWithStep): - ... +class UnrollForWithStep(SerialForWithStep): ... # Python 3.9 compatibility: avoid PEP 604 unions at runtime @@ -172,11 +163,10 @@ TIR_VAR_SCOPE_FRAME = ( def is_var(v: Any) -> bool: - return isinstance(v, Buffer) and v.scope() == 'local.var' + return isinstance(v, Buffer) and v.scope() == "local.var" class Builder(BaseBuilder): - def __init__(self, func_annot: FuncAnnot = None): self.frames: list[AnyFrame] = [] self.ir_builder = IRBuilder() @@ -189,7 +179,7 @@ class Builder(BaseBuilder): @classmethod def current(cls) -> Self: - builder = getattr(thread_local_storage, 'builder', None) + builder = getattr(thread_local_storage, "builder", None) return builder @contextmanager @@ -199,14 +189,15 @@ class Builder(BaseBuilder): tir.func_name(name) yield if len(self.out_idx) != self.out_tensor_cnt: - raise RuntimeError('Not all tensor allocated from `T.empty` are returned') + raise RuntimeError("Not all tensor allocated from `T.empty` are returned") @contextmanager def macro(self, name=None, annotations=None): if self.find_frame_idx(BoolOpFrame) is not None: raise RuntimeError( f"Macro `{name}` is used inside boolean expressions, " - "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs") + "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs" + ) save = self.name_inside_frame, self.macro_arg_annot self.name_inside_frame = {} self.macro_arg_annot = annotations or {} @@ -244,10 +235,7 @@ class Builder(BaseBuilder): def check_continue_break(self): idx = self.find_frame_idx(ContinueOrBreak) if idx is not None: - logger.warning( - 'Writing code after continue/break may cause undefined behavior in tilelang.', - stack_info=True, - stacklevel=3) + logger.warning("Writing code after continue/break may cause undefined behavior in tilelang.", stack_info=True, stacklevel=3) @contextmanager def with_frame(self, frame: AbstractContextManager[Any] | None): @@ -256,8 +244,7 @@ class Builder(BaseBuilder): while len(self.frames) > pop_idx: self.frames.pop().__exit__(None, None, None) - class _has_if_frame: - ... + class _has_if_frame: ... def ctx_if(self, cond): self.check_continue_break() @@ -294,7 +281,7 @@ class Builder(BaseBuilder): elif isinstance(val, tir.frame.IRBuilderFrame): if isinstance(val, tir.frame.ForFrame): logger.warning( - 'Evaluating a for frame may cause undefined behavior in tilelang.', + "Evaluating a for frame may cause undefined behavior in tilelang.", stack_info=True, stacklevel=1, ) @@ -310,8 +297,7 @@ class Builder(BaseBuilder): elif isinstance(val, (Buffer, Var)): pass else: - logger.warning( - f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) + logger.warning(f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) def ctx_for(self, it): self.check_continue_break() @@ -321,15 +307,13 @@ class Builder(BaseBuilder): if isinstance(it.step, (int, IntImm)): step_value = it.step if isinstance(it.step, int) else it.step.value if step_value == 0: - raise ValueError('Invalid stepped serial: step must be non-zero') + raise ValueError("Invalid stepped serial: step must be non-zero") if step_value > 0: real_stop = tir.ceildiv(it.stop - it.start, step_value) else: real_stop = tir.ceildiv(it.start - it.stop, -step_value) else: - logger.warning( - f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' - ) + logger.warning(f"Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang") real_stop = tir.ceildiv(it.stop - it.start, it.step) if isinstance(it, UnrollForWithStep): real_frame = tir.unroll(real_stop, annotations=it.annotations) @@ -338,15 +322,17 @@ class Builder(BaseBuilder): else: raise TypeError( f"Invalid for loop, got {it}({type(it)}), expect one of the following: " - "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding") + "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding" + ) with self.with_frame(real_frame) as v: - IRBuilder.name('_tmp', v) + IRBuilder.name("_tmp", v) yield it.start + v * it.step else: if not isinstance(it, tir.frame.ForFrame): raise TypeError( f"Invalid for loop, got {it}({type(it)}), expect one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding") + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding" + ) with self.with_frame(it) as v: yield v @@ -369,15 +355,16 @@ class Builder(BaseBuilder): if not isinstance(cond_v_unwrap, PrimExpr): if cond_v_unwrap: raise RuntimeError( - f'Infinite while loop detected in TileLang\n' - f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n' + f"Infinite while loop detected in TileLang\n" + f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n" ) else: logger.warning( - 'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n', - f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n', + "While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n", + f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n", stack_info=True, - stacklevel=2) + stacklevel=2, + ) with self.with_frame(tir.While(cond_v_unwrap)): yield None @@ -406,14 +393,14 @@ class Builder(BaseBuilder): # 2. Quick return for trivil types if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)): return value - if isinstance(value, tir.IntImm) and value.dtype == 'int32': + if isinstance(value, tir.IntImm) and value.dtype == "int32": return value.value if isinstance(value, (Var, Buffer)): # Bind TVM Var/Buffer names and also record scope so reusing the same # Python name (e.g., loop vars like `i`) across different for-frames # works without triggering out-of-scope errors. IRBuilder.name(name, value) - if name != '_': + if name != "_": frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) assert frame is not None, f"Variable `{name}` is not defined inside any control flow." self.name_inside_frame[name] = self.frames[frame] @@ -423,12 +410,12 @@ class Builder(BaseBuilder): res = self.bind_immutable(name, value) # 4. Check variable scope and shadowing - if name != '_': + if name != "_": frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) assert frame is not None, f"Variable `{name}` is not defined inside any control flow." if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?', + f"Variable `{name}` is declared twice, are you looking for a T.alloc_var?", stack_info=True, stacklevel=2, ) @@ -436,9 +423,9 @@ class Builder(BaseBuilder): return res def unwrap_value(self, value): - ''' + """ Unwrap some tilelang objects to get their inner value - ''' + """ value = unwrap_expr(value) # handle bx, by = tl.Kernel(128, 128), rval is frame if isinstance(value, tir.frame.IRBuilderFrame): @@ -447,11 +434,11 @@ class Builder(BaseBuilder): return value def bind_immutable(self, name, value): - ''' + """ Bind an immutable tilelang objects. The immutability means the result is usually not changed or re-assigned in a python block. - ''' - if name == '_': + """ + if name == "_": # use _tmp to make the generated tir more readable name = "_tmp" if isinstance(value, tir.meta_var): @@ -459,18 +446,20 @@ class Builder(BaseBuilder): elif isinstance(value, tir.frame.IRBuilderFrame): if isinstance(value, tir.frame.ForFrame): logger.warning( - 'Binding a for frame to variable may cause undefined behavior in tilelang.', + "Binding a for frame to variable may cause undefined behavior in tilelang.", stack_info=True, stacklevel=2, ) return self.enter_frame(value) elif isinstance(value, OutTensor): - arg = tir.arg(name, - tir.buffer( - shape=value.shape, - dtype=value.dtype, - strides=value.strides, - )) + arg = tir.arg( + name, + tir.buffer( + shape=value.shape, + dtype=value.dtype, + strides=value.strides, + ), + ) arg._out_idx = self.out_tensor_cnt self.out_tensor_cnt += 1 return arg @@ -490,8 +479,7 @@ class Builder(BaseBuilder): def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): self.check_continue_break() if annot is not self.empty: - logger.warning( - "Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) + logger.warning("Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) if isinstance(lval, Buffer): tir.buffer_store(lval, value, sl) else: @@ -521,11 +509,11 @@ class Builder(BaseBuilder): left = unwrap_cond(left) if isinstance(left, PrimExpr): with self.with_frame(BoolOpFrame()): - if op == 'And': + if op == "And": return tir.And(left, right()) - if op == 'Or': + if op == "Or": return tir.Or(left, right()) - if op == 'Not': + if op == "Not": return tir.Not(left) raise RuntimeError(f"Unsupported boolean operator: {op}") else: @@ -557,7 +545,7 @@ class Builder(BaseBuilder): "You should allocate a var before the control flow, assign value inside the blocks, \n" "and return the var after the control flow. i.e.\n" "```\n" - "@T.macro\n" \ + "@T.macro\n" "def my_macro(cond):\n" " a = T.alloc_var(T.float16)\n" " if cond:\n" @@ -570,14 +558,12 @@ class Builder(BaseBuilder): if not isinstance(value, tuple): value = (value,) for v in value: - if not isinstance(v, Buffer) or not hasattr(v, '_out_idx'): - raise RuntimeError( - f'Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})' - ) + if not isinstance(v, Buffer) or not hasattr(v, "_out_idx"): + raise RuntimeError(f"Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})") # convert 0, 1, 2 => -3, -2, -1 as the out tensor index self.out_idx.append(v._out_idx - self.out_tensor_cnt) if len(self.out_idx) != self.out_tensor_cnt: - raise RuntimeError(f'Not all tensor from `T.empty` are returned, only got {value}') + raise RuntimeError(f"Not all tensor from `T.empty` are returned, only got {value}") return NotImplemented def ctx_with(self, ctx): @@ -591,7 +577,7 @@ class Builder(BaseBuilder): self.check_continue_break() cond = unwrap_cond(cond) if msg is None: - msg = 'Assertion failed' + msg = "Assertion failed" if isinstance(cond, PrimExpr): self.enter_frame(tir.Assert(cond, msg)) elif not cond: @@ -611,23 +597,18 @@ class Builder(BaseBuilder): annot_value = self.macro_arg_annot.get(name, None) if annot_value is Var or annot_value is Ref: if annot_value is Var: - logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`') + logger.warning("Use `T.Var` as macro annotations is deprecated, please use `T.Ref`") if isinstance(value, BufferLoad): if is_var(value.buffer): return value.buffer - idx = [self.bind('_', idx) for idx in value.indices] + idx = [self.bind("_", idx) for idx in value.indices] # indices = self.bind(f'_', value.indices) return Ref(BufferLoad(value.buffer, indices=idx)) if isinstance(value, BufferRegion): - region = [ - Range( - self.bind('_', x.begin), - end=self.bind('_', x.end) if x.end is not None else None) - for x in value.region - ] + region = [Range(self.bind("_", x.begin), end=self.bind("_", x.end) if x.end is not None else None) for x in value.region] return BufferRegion(value.buffer, region=region) raise ValueError( - f'To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})' + f"To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})" ) elif isinstance(value, (PrimExpr, int, float)): return self.bind(name, value) @@ -652,13 +633,14 @@ class Builder(BaseBuilder): def override(self, name: str): from tilelang.language import serial - if name == 'range': + + if name == "range": return serial - raise ValueError(f'Unknown override: {name}') + raise ValueError(f"Unknown override: {name}") -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") @dataclass @@ -683,14 +665,8 @@ class PrimFuncCreater(Generic[_P, _T]): return res def __repr__(self): - fmt = pprint.pformat( - { - 'annot': self.func_annot.annots, - 'ir_gen': self.ir_gen, - 'orig_func': self.orig_func - }, - indent=2) - return f'{self.__class__.__name__}(\n{fmt}\n)' + fmt = pprint.pformat({"annot": self.func_annot.annots, "ir_gen": self.ir_gen, "orig_func": self.orig_func}, indent=2) + return f"{self.__class__.__name__}(\n{fmt}\n)" if TYPE_CHECKING: @@ -769,8 +745,7 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: annotations = get_type_hints(func) - return Macro( - name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) + return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) return impl(func) if func is not None else impl @@ -779,9 +754,9 @@ from typing import _eval_type def get_type_hints(func): - annot = getattr(func, '__annotations__', None) + annot = getattr(func, "__annotations__", None) if annot is None: - raise TypeError(f'Failed to get function type hints, {func} is not a function') + raise TypeError(f"Failed to get function type hints, {func} is not a function") hints = {} # Build eval namespaces from function globals plus captured closure variables # This lets annotations reference symbols like `n`, `h`, or dtype vars @@ -808,7 +783,7 @@ def get_type_hints(func): # ... # empty function, do not use `n` localns = utils.get_func_nonlocals(func) for name, value in annot.items(): - if name == 'return': + if name == "return": continue if isinstance(value, tvm.DataType): hints[name] = value @@ -821,7 +796,7 @@ def get_type_hints(func): # typing see: T.float32 is str('float32'), and there is no object named `flaot32` and give a NameError # here we manually interpret it to return T.float32 object try: - _, v = value.split('.', maxsplit=1) + _, v = value.split(".", maxsplit=1) except ValueError: v = value if v in dt._all_dtypes: @@ -837,9 +812,7 @@ def get_type_hints(func): return hints -def prim_func(func: Callable[_P, _T] = None, - *, - generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]: +def prim_func(func: Callable[_P, _T] = None, *, generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]: """ Decorator to create a primitive function (PrimFunc) for TileLang IR generation. This decorator transforms a Python function into a TileLang primitive function by analyzing @@ -903,7 +876,8 @@ def prim_func(func: Callable[_P, _T] = None, raise ValueError( f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n" f"Annotations:\n{func_annot.annots}" - f"Unknown Args: {unknown_args}") + f"Unknown Args: {unknown_args}" + ) return prim_func_generator return impl(func) if func is not None else impl diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 3b21587..6ed56b4 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -6,14 +6,12 @@ from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi import numpy as np -_T = TypeVar('_T') +_T = TypeVar("_T") if TYPE_CHECKING: class dtype(Generic[_T]): - - def torch(self) -> torch.dtype: - ... + def torch(self) -> torch.dtype: ... else: dtype = tvm.DataType @@ -21,53 +19,53 @@ else: AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] _PYTHON_DTYPE_TO_STR = { - bool: 'bool', - int: 'int32', - float: 'float32', + bool: "bool", + int: "int32", + float: "float32", } _NUMPY_DTYPE_TO_STR = { - np.bool_: 'bool', - np.short: 'int16', - np.int_: 'int64', - np.longlong: 'int64', - np.half: 'float16', - np.double: 'float64', - np.int8: 'int8', - np.int16: 'int16', - np.int32: 'int32', - np.int64: 'int64', - np.uint8: 'uint8', - np.uint16: 'uint16', - np.uint32: 'uint32', - np.uint64: 'uint64', - np.float16: 'float16', - np.float32: 'float32', - np.float64: 'float64', + np.bool_: "bool", + np.short: "int16", + np.int_: "int64", + np.longlong: "int64", + np.half: "float16", + np.double: "float64", + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", + np.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", + np.float16: "float16", + np.float32: "float32", + np.float64: "float64", } _NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) _TORCH_DTYPE_TO_STR = { - torch.bool: 'bool', - torch.short: 'int16', - torch.int: 'int32', - torch.long: 'int64', - torch.half: 'float16', - torch.float: 'float32', - torch.double: 'float64', - torch.int8: 'int8', - torch.int16: 'int16', - torch.int32: 'int32', - torch.int64: 'int64', - torch.uint8: 'uint8', - torch.uint16: 'uint16', - torch.uint32: 'uint32', - torch.uint64: 'uint64', - torch.float16: 'float16', - torch.float32: 'float32', - torch.float64: 'float64', - torch.bfloat16: 'bfloat16', + torch.bool: "bool", + torch.short: "int16", + torch.int: "int32", + torch.long: "int64", + torch.half: "float16", + torch.float: "float32", + torch.double: "float64", + torch.int8: "int8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", + torch.uint8: "uint8", + torch.uint16: "uint16", + torch.uint32: "uint32", + torch.uint64: "uint64", + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.bfloat16: "bfloat16", } # _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} @@ -77,24 +75,24 @@ _TORCH_DTYPE_TO_STR = { _DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR} _STR_TO_TVM_DTYPE_CALL = { - 'bool': 'Boolean', - 'int8': 'Int8', - 'int32': 'Int32', - 'int64': 'Int64', - 'uint8': 'UInt8', - 'uint16': 'UInt16', - 'uint32': 'UInt32', - 'uint64': 'UInt64', - 'float16': 'Float16', - 'float32': 'Float32', - 'float64': 'Float64', - 'bfloat16': 'BFloat16', - 'float8_e4m3': 'Float8E4M3', - 'float8_e4m3fn': 'Float8E4M3FN', - 'float8_e4m3fnuz': 'Float8E4M3FNUZ', - 'float8_e5m2': 'Float8E5M2', - 'float8_e5m2fnuz': 'Float8E5M2FNUZ', - 'float8_e8m0fnu': 'Float8E8M0FNU' + "bool": "Boolean", + "int8": "Int8", + "int32": "Int32", + "int64": "Int64", + "uint8": "UInt8", + "uint16": "UInt16", + "uint32": "UInt32", + "uint64": "UInt64", + "float16": "Float16", + "float32": "Float32", + "float64": "Float64", + "bfloat16": "BFloat16", + "float8_e4m3": "Float8E4M3", + "float8_e4m3fn": "Float8E4M3FN", + "float8_e4m3fnuz": "Float8E4M3FNUZ", + "float8_e5m2": "Float8E5M2", + "float8_e5m2fnuz": "Float8E5M2FNUZ", + "float8_e8m0fnu": "Float8E8M0FNU", } int_ = int @@ -108,23 +106,24 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var call = getattr(tb_ffi, attr, None) return call(expr, is_size_var) # try to construct the ffi call - if self.startswith('uint'): - val = 'UInt' + self[4:] - elif self.startswith('int'): - val = 'Int' + self[3:] - elif self.startswith('float'): - val = 'Float' + self[5:] - elif self.startswith('bfloat'): - val = 'BFloat' + self[6:] + if self.startswith("uint"): + val = "UInt" + self[4:] + elif self.startswith("int"): + val = "Int" + self[3:] + elif self.startswith("float"): + val = "Float" + self[5:] + elif self.startswith("bfloat"): + val = "BFloat" + self[6:] else: - raise TypeError(f'Invalid type {self}') - if '_' in val: - first, second = val.split('_', maxsplit=1) + raise TypeError(f"Invalid type {self}") + if "_" in val: + first, second = val.split("_", maxsplit=1) val = first + second.upper() call = getattr(tb_ffi, val, None) if call is None: - raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n" - f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`") + raise TypeError( + f"Convert to datatype `{self}` is not supported by tvm\ncalling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`" + ) return call(expr, is_size_var) @@ -152,7 +151,6 @@ def get_tvm_dtype(value: AnyDType) -> dtype: if TYPE_CHECKING: - # yapf: disable class bool(dtype): ... class short(dtype): ... @@ -319,336 +317,336 @@ if TYPE_CHECKING: # yapf: enable else: - bool = dtype('bool') - short = dtype('int16') - int = dtype('int32') - long = dtype('int64') - half = dtype('float16') - float = dtype('float32') - double = dtype('float64') - int8 = dtype('int8') - int16 = dtype('int16') - int32 = dtype('int32') - int64 = dtype('int64') - int8x2 = dtype('int8x2') - int16x2 = dtype('int16x2') - int32x2 = dtype('int32x2') - int64x2 = dtype('int64x2') - int8x4 = dtype('int8x4') - int16x4 = dtype('int16x4') - int32x4 = dtype('int32x4') - int64x4 = dtype('int64x4') - int8x8 = dtype('int8x8') - int16x8 = dtype('int16x8') - int32x8 = dtype('int32x8') - int64x8 = dtype('int64x8') - int8x16 = dtype('int8x16') - int16x16 = dtype('int16x16') - int32x16 = dtype('int32x16') - int64x16 = dtype('int64x16') - int8x32 = dtype('int8x32') - int16x32 = dtype('int16x32') - int32x32 = dtype('int32x32') - int64x32 = dtype('int64x32') - int8x64 = dtype('int8x64') - int16x64 = dtype('int16x64') - int32x64 = dtype('int32x64') - int64x64 = dtype('int64x64') - uint8 = dtype('uint8') - uint16 = dtype('uint16') - uint32 = dtype('uint32') - uint64 = dtype('uint64') - uint8x2 = dtype('uint8x2') - uint16x2 = dtype('uint16x2') - uint32x2 = dtype('uint32x2') - uint64x2 = dtype('uint64x2') - uint8x4 = dtype('uint8x4') - uint16x4 = dtype('uint16x4') - uint32x4 = dtype('uint32x4') - uint64x4 = dtype('uint64x4') - uint8x8 = dtype('uint8x8') - uint16x8 = dtype('uint16x8') - uint32x8 = dtype('uint32x8') - uint64x8 = dtype('uint64x8') - uint8x16 = dtype('uint8x16') - uint16x16 = dtype('uint16x16') - uint32x16 = dtype('uint32x16') - uint64x16 = dtype('uint64x16') - uint8x32 = dtype('uint8x32') - uint16x32 = dtype('uint16x32') - uint32x32 = dtype('uint32x32') - uint64x32 = dtype('uint64x32') - uint8x64 = dtype('uint8x64') - uint16x64 = dtype('uint16x64') - uint32x64 = dtype('uint32x64') - uint64x64 = dtype('uint64x64') - float16 = dtype('float16') - float32 = dtype('float32') - float64 = dtype('float64') - float16x2 = dtype('float16x2') - float32x2 = dtype('float32x2') - float64x2 = dtype('float64x2') - float16x4 = dtype('float16x4') - float32x4 = dtype('float32x4') - float64x4 = dtype('float64x4') - float16x8 = dtype('float16x8') - float32x8 = dtype('float32x8') - float64x8 = dtype('float64x8') - float16x16 = dtype('float16x16') - float32x16 = dtype('float32x16') - float64x16 = dtype('float64x16') - float16x32 = dtype('float16x32') - float32x32 = dtype('float32x32') - float64x32 = dtype('float64x32') - float16x64 = dtype('float16x64') - float32x64 = dtype('float32x64') - float64x64 = dtype('float64x64') - float8_e3m4 = dtype('float8_e3m4') - float8_e3m4x2 = dtype('float8_e3m4x2') - float8_e3m4x4 = dtype('float8_e3m4x4') - float8_e3m4x8 = dtype('float8_e3m4x8') - float8_e3m4x16 = dtype('float8_e3m4x16') - float8_e3m4x32 = dtype('float8_e3m4x32') - float8_e3m4x64 = dtype('float8_e3m4x64') - float8_e4m3 = dtype('float8_e4m3') - float8_e4m3x2 = dtype('float8_e4m3x2') - float8_e4m3x4 = dtype('float8_e4m3x4') - float8_e4m3x8 = dtype('float8_e4m3x8') - float8_e4m3x16 = dtype('float8_e4m3x16') - float8_e4m3x32 = dtype('float8_e4m3x32') - float8_e4m3x64 = dtype('float8_e4m3x64') - float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz') - float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2') - float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4') - float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8') - float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16') - float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32') - float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64') - float8_e4m3fn = dtype('float8_e4m3fn') - float8_e4m3fnx2 = dtype('float8_e4m3fnx2') - float8_e4m3fnx4 = dtype('float8_e4m3fnx4') - float8_e4m3fnx8 = dtype('float8_e4m3fnx8') - float8_e4m3fnx16 = dtype('float8_e4m3fnx16') - float8_e4m3fnx32 = dtype('float8_e4m3fnx32') - float8_e4m3fnx64 = dtype('float8_e4m3fnx64') - float8_e4m3fnuz = dtype('float8_e4m3fnuz') - float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2') - float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4') - float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8') - float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16') - float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32') - float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64') - float8_e5m2 = dtype('float8_e5m2') - float8_e5m2x2 = dtype('float8_e5m2x2') - float8_e5m2x4 = dtype('float8_e5m2x4') - float8_e5m2x8 = dtype('float8_e5m2x8') - float8_e5m2x16 = dtype('float8_e5m2x16') - float8_e5m2x32 = dtype('float8_e5m2x32') - float8_e5m2x64 = dtype('float8_e5m2x64') - float8_e5m2fnuz = dtype('float8_e5m2fnuz') - float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2') - float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4') - float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8') - float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16') - float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32') - float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64') - float8_e8m0fnu = dtype('float8_e8m0fnu') - float8_e8m0fnux2 = dtype('float8_e8m0fnux2') - float8_e8m0fnux4 = dtype('float8_e8m0fnux4') - float8_e8m0fnux8 = dtype('float8_e8m0fnux8') - float8_e8m0fnux16 = dtype('float8_e8m0fnux16') - float8_e8m0fnux32 = dtype('float8_e8m0fnux32') - float8_e8m0fnux64 = dtype('float8_e8m0fnux64') - float6_e2m3fn = dtype('float6_e2m3fn') - float6_e2m3fnx2 = dtype('float6_e2m3fnx2') - float6_e2m3fnx4 = dtype('float6_e2m3fnx4') - float6_e2m3fnx8 = dtype('float6_e2m3fnx8') - float6_e2m3fnx16 = dtype('float6_e2m3fnx16') - float6_e2m3fnx32 = dtype('float6_e2m3fnx32') - float6_e2m3fnx64 = dtype('float6_e2m3fnx64') - float6_e3m2fn = dtype('float6_e3m2fn') - float6_e3m2fnx2 = dtype('float6_e3m2fnx2') - float6_e3m2fnx4 = dtype('float6_e3m2fnx4') - float6_e3m2fnx8 = dtype('float6_e3m2fnx8') - float6_e3m2fnx16 = dtype('float6_e3m2fnx16') - float6_e3m2fnx32 = dtype('float6_e3m2fnx32') - float6_e3m2fnx64 = dtype('float6_e3m2fnx64') - float4_e2m1fn = dtype('float4_e2m1fn') - float4_e2m1fnx2 = dtype('float4_e2m1fnx2') - float4_e2m1fnx4 = dtype('float4_e2m1fnx4') - float4_e2m1fnx8 = dtype('float4_e2m1fnx8') - float4_e2m1fnx16 = dtype('float4_e2m1fnx16') - float4_e2m1fnx32 = dtype('float4_e2m1fnx32') - float4_e2m1fnx64 = dtype('float4_e2m1fnx64') - bfloat16 = dtype('bfloat16') + bool = dtype("bool") + short = dtype("int16") + int = dtype("int32") + long = dtype("int64") + half = dtype("float16") + float = dtype("float32") + double = dtype("float64") + int8 = dtype("int8") + int16 = dtype("int16") + int32 = dtype("int32") + int64 = dtype("int64") + int8x2 = dtype("int8x2") + int16x2 = dtype("int16x2") + int32x2 = dtype("int32x2") + int64x2 = dtype("int64x2") + int8x4 = dtype("int8x4") + int16x4 = dtype("int16x4") + int32x4 = dtype("int32x4") + int64x4 = dtype("int64x4") + int8x8 = dtype("int8x8") + int16x8 = dtype("int16x8") + int32x8 = dtype("int32x8") + int64x8 = dtype("int64x8") + int8x16 = dtype("int8x16") + int16x16 = dtype("int16x16") + int32x16 = dtype("int32x16") + int64x16 = dtype("int64x16") + int8x32 = dtype("int8x32") + int16x32 = dtype("int16x32") + int32x32 = dtype("int32x32") + int64x32 = dtype("int64x32") + int8x64 = dtype("int8x64") + int16x64 = dtype("int16x64") + int32x64 = dtype("int32x64") + int64x64 = dtype("int64x64") + uint8 = dtype("uint8") + uint16 = dtype("uint16") + uint32 = dtype("uint32") + uint64 = dtype("uint64") + uint8x2 = dtype("uint8x2") + uint16x2 = dtype("uint16x2") + uint32x2 = dtype("uint32x2") + uint64x2 = dtype("uint64x2") + uint8x4 = dtype("uint8x4") + uint16x4 = dtype("uint16x4") + uint32x4 = dtype("uint32x4") + uint64x4 = dtype("uint64x4") + uint8x8 = dtype("uint8x8") + uint16x8 = dtype("uint16x8") + uint32x8 = dtype("uint32x8") + uint64x8 = dtype("uint64x8") + uint8x16 = dtype("uint8x16") + uint16x16 = dtype("uint16x16") + uint32x16 = dtype("uint32x16") + uint64x16 = dtype("uint64x16") + uint8x32 = dtype("uint8x32") + uint16x32 = dtype("uint16x32") + uint32x32 = dtype("uint32x32") + uint64x32 = dtype("uint64x32") + uint8x64 = dtype("uint8x64") + uint16x64 = dtype("uint16x64") + uint32x64 = dtype("uint32x64") + uint64x64 = dtype("uint64x64") + float16 = dtype("float16") + float32 = dtype("float32") + float64 = dtype("float64") + float16x2 = dtype("float16x2") + float32x2 = dtype("float32x2") + float64x2 = dtype("float64x2") + float16x4 = dtype("float16x4") + float32x4 = dtype("float32x4") + float64x4 = dtype("float64x4") + float16x8 = dtype("float16x8") + float32x8 = dtype("float32x8") + float64x8 = dtype("float64x8") + float16x16 = dtype("float16x16") + float32x16 = dtype("float32x16") + float64x16 = dtype("float64x16") + float16x32 = dtype("float16x32") + float32x32 = dtype("float32x32") + float64x32 = dtype("float64x32") + float16x64 = dtype("float16x64") + float32x64 = dtype("float32x64") + float64x64 = dtype("float64x64") + float8_e3m4 = dtype("float8_e3m4") + float8_e3m4x2 = dtype("float8_e3m4x2") + float8_e3m4x4 = dtype("float8_e3m4x4") + float8_e3m4x8 = dtype("float8_e3m4x8") + float8_e3m4x16 = dtype("float8_e3m4x16") + float8_e3m4x32 = dtype("float8_e3m4x32") + float8_e3m4x64 = dtype("float8_e3m4x64") + float8_e4m3 = dtype("float8_e4m3") + float8_e4m3x2 = dtype("float8_e4m3x2") + float8_e4m3x4 = dtype("float8_e4m3x4") + float8_e4m3x8 = dtype("float8_e4m3x8") + float8_e4m3x16 = dtype("float8_e4m3x16") + float8_e4m3x32 = dtype("float8_e4m3x32") + float8_e4m3x64 = dtype("float8_e4m3x64") + float8_e4m3b11fnuz = dtype("float8_e4m3b11fnuz") + float8_e4m3b11fnuzx2 = dtype("float8_e4m3b11fnuzx2") + float8_e4m3b11fnuzx4 = dtype("float8_e4m3b11fnuzx4") + float8_e4m3b11fnuzx8 = dtype("float8_e4m3b11fnuzx8") + float8_e4m3b11fnuzx16 = dtype("float8_e4m3b11fnuzx16") + float8_e4m3b11fnuzx32 = dtype("float8_e4m3b11fnuzx32") + float8_e4m3b11fnuzx64 = dtype("float8_e4m3b11fnuzx64") + float8_e4m3fn = dtype("float8_e4m3fn") + float8_e4m3fnx2 = dtype("float8_e4m3fnx2") + float8_e4m3fnx4 = dtype("float8_e4m3fnx4") + float8_e4m3fnx8 = dtype("float8_e4m3fnx8") + float8_e4m3fnx16 = dtype("float8_e4m3fnx16") + float8_e4m3fnx32 = dtype("float8_e4m3fnx32") + float8_e4m3fnx64 = dtype("float8_e4m3fnx64") + float8_e4m3fnuz = dtype("float8_e4m3fnuz") + float8_e4m3fnuzx2 = dtype("float8_e4m3fnuzx2") + float8_e4m3fnuzx4 = dtype("float8_e4m3fnuzx4") + float8_e4m3fnuzx8 = dtype("float8_e4m3fnuzx8") + float8_e4m3fnuzx16 = dtype("float8_e4m3fnuzx16") + float8_e4m3fnuzx32 = dtype("float8_e4m3fnuzx32") + float8_e4m3fnuzx64 = dtype("float8_e4m3fnuzx64") + float8_e5m2 = dtype("float8_e5m2") + float8_e5m2x2 = dtype("float8_e5m2x2") + float8_e5m2x4 = dtype("float8_e5m2x4") + float8_e5m2x8 = dtype("float8_e5m2x8") + float8_e5m2x16 = dtype("float8_e5m2x16") + float8_e5m2x32 = dtype("float8_e5m2x32") + float8_e5m2x64 = dtype("float8_e5m2x64") + float8_e5m2fnuz = dtype("float8_e5m2fnuz") + float8_e5m2fnuzx2 = dtype("float8_e5m2fnuzx2") + float8_e5m2fnuzx4 = dtype("float8_e5m2fnuzx4") + float8_e5m2fnuzx8 = dtype("float8_e5m2fnuzx8") + float8_e5m2fnuzx16 = dtype("float8_e5m2fnuzx16") + float8_e5m2fnuzx32 = dtype("float8_e5m2fnuzx32") + float8_e5m2fnuzx64 = dtype("float8_e5m2fnuzx64") + float8_e8m0fnu = dtype("float8_e8m0fnu") + float8_e8m0fnux2 = dtype("float8_e8m0fnux2") + float8_e8m0fnux4 = dtype("float8_e8m0fnux4") + float8_e8m0fnux8 = dtype("float8_e8m0fnux8") + float8_e8m0fnux16 = dtype("float8_e8m0fnux16") + float8_e8m0fnux32 = dtype("float8_e8m0fnux32") + float8_e8m0fnux64 = dtype("float8_e8m0fnux64") + float6_e2m3fn = dtype("float6_e2m3fn") + float6_e2m3fnx2 = dtype("float6_e2m3fnx2") + float6_e2m3fnx4 = dtype("float6_e2m3fnx4") + float6_e2m3fnx8 = dtype("float6_e2m3fnx8") + float6_e2m3fnx16 = dtype("float6_e2m3fnx16") + float6_e2m3fnx32 = dtype("float6_e2m3fnx32") + float6_e2m3fnx64 = dtype("float6_e2m3fnx64") + float6_e3m2fn = dtype("float6_e3m2fn") + float6_e3m2fnx2 = dtype("float6_e3m2fnx2") + float6_e3m2fnx4 = dtype("float6_e3m2fnx4") + float6_e3m2fnx8 = dtype("float6_e3m2fnx8") + float6_e3m2fnx16 = dtype("float6_e3m2fnx16") + float6_e3m2fnx32 = dtype("float6_e3m2fnx32") + float6_e3m2fnx64 = dtype("float6_e3m2fnx64") + float4_e2m1fn = dtype("float4_e2m1fn") + float4_e2m1fnx2 = dtype("float4_e2m1fnx2") + float4_e2m1fnx4 = dtype("float4_e2m1fnx4") + float4_e2m1fnx8 = dtype("float4_e2m1fnx8") + float4_e2m1fnx16 = dtype("float4_e2m1fnx16") + float4_e2m1fnx32 = dtype("float4_e2m1fnx32") + float4_e2m1fnx64 = dtype("float4_e2m1fnx64") + bfloat16 = dtype("bfloat16") _all_dtypes = { - 'bool', - 'short', - 'int', - 'long', - 'half', - 'float', - 'double', - 'int8', - 'int16', - 'int32', - 'int64', - 'int8x2', - 'int16x2', - 'int32x2', - 'int64x2', - 'int8x4', - 'int16x4', - 'int32x4', - 'int64x4', - 'int8x8', - 'int16x8', - 'int32x8', - 'int64x8', - 'int8x16', - 'int16x16', - 'int32x16', - 'int64x16', - 'int8x32', - 'int16x32', - 'int32x32', - 'int64x32', - 'int8x64', - 'int16x64', - 'int32x64', - 'int64x64', - 'uint8', - 'uint16', - 'uint32', - 'uint64', - 'uint8x2', - 'uint16x2', - 'uint32x2', - 'uint64x2', - 'uint8x4', - 'uint16x4', - 'uint32x4', - 'uint64x4', - 'uint8x8', - 'uint16x8', - 'uint32x8', - 'uint64x8', - 'uint8x16', - 'uint16x16', - 'uint32x16', - 'uint64x16', - 'uint8x32', - 'uint16x32', - 'uint32x32', - 'uint64x32', - 'uint8x64', - 'uint16x64', - 'uint32x64', - 'uint64x64', - 'float16', - 'float32', - 'float64', - 'float16x2', - 'float32x2', - 'float64x2', - 'float16x4', - 'float32x4', - 'float64x4', - 'float16x8', - 'float32x8', - 'float64x8', - 'float16x16', - 'float32x16', - 'float64x16', - 'float16x32', - 'float32x32', - 'float64x32', - 'float16x64', - 'float32x64', - 'float64x64', - 'float8_e3m4', - 'float8_e3m4x2', - 'float8_e3m4x4', - 'float8_e3m4x8', - 'float8_e3m4x16', - 'float8_e3m4x32', - 'float8_e3m4x64', - 'float8_e4m3', - 'float8_e4m3x2', - 'float8_e4m3x4', - 'float8_e4m3x8', - 'float8_e4m3x16', - 'float8_e4m3x32', - 'float8_e4m3x64', - 'float8_e4m3b11fnuz', - 'float8_e4m3b11fnuzx2', - 'float8_e4m3b11fnuzx4', - 'float8_e4m3b11fnuzx8', - 'float8_e4m3b11fnuzx16', - 'float8_e4m3b11fnuzx32', - 'float8_e4m3b11fnuzx64', - 'float8_e4m3fn', - 'float8_e4m3fnx2', - 'float8_e4m3fnx4', - 'float8_e4m3fnx8', - 'float8_e4m3fnx16', - 'float8_e4m3fnx32', - 'float8_e4m3fnx64', - 'float8_e4m3fnuz', - 'float8_e4m3fnuzx2', - 'float8_e4m3fnuzx4', - 'float8_e4m3fnuzx8', - 'float8_e4m3fnuzx16', - 'float8_e4m3fnuzx32', - 'float8_e4m3fnuzx64', - 'float8_e5m2', - 'float8_e5m2x2', - 'float8_e5m2x4', - 'float8_e5m2x8', - 'float8_e5m2x16', - 'float8_e5m2x32', - 'float8_e5m2x64', - 'float8_e5m2fnuz', - 'float8_e5m2fnuzx2', - 'float8_e5m2fnuzx4', - 'float8_e5m2fnuzx8', - 'float8_e5m2fnuzx16', - 'float8_e5m2fnuzx32', - 'float8_e5m2fnuzx64', - 'float8_e8m0fnu', - 'float8_e8m0fnux2', - 'float8_e8m0fnux4', - 'float8_e8m0fnux8', - 'float8_e8m0fnux16', - 'float8_e8m0fnux32', - 'float8_e8m0fnux64', - 'float6_e2m3fn', - 'float6_e2m3fnx2', - 'float6_e2m3fnx4', - 'float6_e2m3fnx8', - 'float6_e2m3fnx16', - 'float6_e2m3fnx32', - 'float6_e2m3fnx64', - 'float6_e3m2fn', - 'float6_e3m2fnx2', - 'float6_e3m2fnx4', - 'float6_e3m2fnx8', - 'float6_e3m2fnx16', - 'float6_e3m2fnx32', - 'float6_e3m2fnx64', - 'float4_e2m1fn', - 'float4_e2m1fnx2', - 'float4_e2m1fnx4', - 'float4_e2m1fnx8', - 'float4_e2m1fnx16', - 'float4_e2m1fnx32', - 'float4_e2m1fnx64', - 'bfloat16', + "bool", + "short", + "int", + "long", + "half", + "float", + "double", + "int8", + "int16", + "int32", + "int64", + "int8x2", + "int16x2", + "int32x2", + "int64x2", + "int8x4", + "int16x4", + "int32x4", + "int64x4", + "int8x8", + "int16x8", + "int32x8", + "int64x8", + "int8x16", + "int16x16", + "int32x16", + "int64x16", + "int8x32", + "int16x32", + "int32x32", + "int64x32", + "int8x64", + "int16x64", + "int32x64", + "int64x64", + "uint8", + "uint16", + "uint32", + "uint64", + "uint8x2", + "uint16x2", + "uint32x2", + "uint64x2", + "uint8x4", + "uint16x4", + "uint32x4", + "uint64x4", + "uint8x8", + "uint16x8", + "uint32x8", + "uint64x8", + "uint8x16", + "uint16x16", + "uint32x16", + "uint64x16", + "uint8x32", + "uint16x32", + "uint32x32", + "uint64x32", + "uint8x64", + "uint16x64", + "uint32x64", + "uint64x64", + "float16", + "float32", + "float64", + "float16x2", + "float32x2", + "float64x2", + "float16x4", + "float32x4", + "float64x4", + "float16x8", + "float32x8", + "float64x8", + "float16x16", + "float32x16", + "float64x16", + "float16x32", + "float32x32", + "float64x32", + "float16x64", + "float32x64", + "float64x64", + "float8_e3m4", + "float8_e3m4x2", + "float8_e3m4x4", + "float8_e3m4x8", + "float8_e3m4x16", + "float8_e3m4x32", + "float8_e3m4x64", + "float8_e4m3", + "float8_e4m3x2", + "float8_e4m3x4", + "float8_e4m3x8", + "float8_e4m3x16", + "float8_e4m3x32", + "float8_e4m3x64", + "float8_e4m3b11fnuz", + "float8_e4m3b11fnuzx2", + "float8_e4m3b11fnuzx4", + "float8_e4m3b11fnuzx8", + "float8_e4m3b11fnuzx16", + "float8_e4m3b11fnuzx32", + "float8_e4m3b11fnuzx64", + "float8_e4m3fn", + "float8_e4m3fnx2", + "float8_e4m3fnx4", + "float8_e4m3fnx8", + "float8_e4m3fnx16", + "float8_e4m3fnx32", + "float8_e4m3fnx64", + "float8_e4m3fnuz", + "float8_e4m3fnuzx2", + "float8_e4m3fnuzx4", + "float8_e4m3fnuzx8", + "float8_e4m3fnuzx16", + "float8_e4m3fnuzx32", + "float8_e4m3fnuzx64", + "float8_e5m2", + "float8_e5m2x2", + "float8_e5m2x4", + "float8_e5m2x8", + "float8_e5m2x16", + "float8_e5m2x32", + "float8_e5m2x64", + "float8_e5m2fnuz", + "float8_e5m2fnuzx2", + "float8_e5m2fnuzx4", + "float8_e5m2fnuzx8", + "float8_e5m2fnuzx16", + "float8_e5m2fnuzx32", + "float8_e5m2fnuzx64", + "float8_e8m0fnu", + "float8_e8m0fnux2", + "float8_e8m0fnux4", + "float8_e8m0fnux8", + "float8_e8m0fnux16", + "float8_e8m0fnux32", + "float8_e8m0fnux64", + "float6_e2m3fn", + "float6_e2m3fnx2", + "float6_e2m3fnx4", + "float6_e2m3fnx8", + "float6_e2m3fnx16", + "float6_e2m3fnx32", + "float6_e2m3fnx64", + "float6_e3m2fn", + "float6_e3m2fnx2", + "float6_e3m2fnx4", + "float6_e3m2fnx8", + "float6_e3m2fnx16", + "float6_e3m2fnx32", + "float6_e3m2fnx64", + "float4_e2m1fn", + "float4_e2m1fnx2", + "float4_e2m1fnx4", + "float4_e2m1fnx8", + "float4_e2m1fnx16", + "float4_e2m1fnx32", + "float4_e2m1fnx64", + "bfloat16", } __all__ = list(_all_dtypes) + [ - 'dtype', - 'AnyDType', - 'get_tvm_dtype', + "dtype", + "AnyDType", + "get_tvm_dtype", ] diff --git a/tilelang/language/v2/utils.py b/tilelang/language/v2/utils.py index 022402d..207bd92 100644 --- a/tilelang/language/v2/utils.py +++ b/tilelang/language/v2/utils.py @@ -12,11 +12,12 @@ def disk_compile(source, name): cache_dir = env.TILELANG_CACHE_DIR if cache_dir is not None: import os + save_dir = os.path.join(cache_dir, "py-cache") os.makedirs(save_dir, exist_ok=True) - hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8] + hash_sfx = sha256(source.encode("utf-8")).hexdigest()[:8] path = os.path.join(save_dir, f"{name}.{hash_sfx}.py") - with open(path, 'w') as f: + with open(path, "w") as f: f.write(source) linecache.cache[path] = (len(source), None, source.splitlines(), path) return compile(source, path, "exec") @@ -59,29 +60,26 @@ def get_ast(func: Callable): filename = inspect.getsourcefile(func) or inspect.getfile(func) source = inspect.getsource(func) source = _remove_leading_ident(source) - source = '\n' * (start - 1) + source + source = "\n" * (start - 1) + source tree = ast.parse(source, filename=filename) return tree -CompileMethod = Literal['direct', 'disk'] +CompileMethod = Literal["direct", "disk"] -def get_compiled_object(source: str | ast.AST, - name: str, - filename: str = None, - globals: dict[str, Any] = None): +def get_compiled_object(source: str | ast.AST, name: str, filename: str = None, globals: dict[str, Any] = None): if isinstance(source, ast.AST): assert filename is not None, "filename must be provided when source is an AST" try: if isinstance(source, ast.AST): ast.fix_missing_locations(source) - compiled = compile(source, filename, 'exec') + compiled = compile(source, filename, "exec") else: compiled = disk_compile(source, name) except Exception as e: source_str = source if isinstance(source, str) else ast.unparse(source) - raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e + raise RuntimeError(f"Failed to compile source for {name}, Error: {e}:\n{source_str}") from e locs = {} exec(compiled, globals, locs) return locs[name] @@ -95,7 +93,6 @@ def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> t strides.append(stride) stride *= s if not allow_prim_expr and isinstance(stride, tir.PrimExpr): - raise ValueError( - "Cannot construct strides with PrimExpr when allow_prim_expr is False.") + raise ValueError("Cannot construct strides with PrimExpr when allow_prim_expr is False.") strides = tuple(reversed(strides)) return strides diff --git a/tilelang/language/warpgroup.py b/tilelang/language/warpgroup.py index bec7680..77cf692 100644 --- a/tilelang/language/warpgroup.py +++ b/tilelang/language/warpgroup.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.ffi import register_object from tilelang import _ffi_api diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index ff45f6d..256a7d5 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation import tvm import tvm_ffi @@ -20,12 +21,7 @@ class Fragment(Layout): # Disable the linter warning about not calling super().__init__() # because this object is created via TVM's FFI constructor mechanism. # pylint: disable=super-init-not-called - def __init__(self, - shape, - forward_fn=None, - forward_thread_fn=None, - replicate=1, - forward_index_fn=None): + def __init__(self, shape, forward_fn=None, forward_thread_fn=None, replicate=1, forward_index_fn=None): """ Initialize the Fragment with iteration variables and optional thread replication. @@ -119,10 +115,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_thread_size(self) - def repeat(self, - repeats, - repeat_on_thread: bool = False, - lower_dim_first: bool = True) -> 'Fragment': + def repeat(self, repeats, repeat_on_thread: bool = False, lower_dim_first: bool = True) -> "Fragment": """ Returns a new Fragment that repeats the iteration space a given number of times. @@ -142,7 +135,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) - def replicate(self, replicate: int) -> 'Fragment': + def replicate(self, replicate: int) -> "Fragment": """ Replicate the Fragment across a new thread dimension. @@ -158,7 +151,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_replicate(self, replicate) - def condense_rep_var(self) -> 'Fragment': + def condense_rep_var(self) -> "Fragment": """ Condense or fold the replicate variable into the existing iteration space. This operation may be used to reduce dimensionality if the replicate variable @@ -190,8 +183,7 @@ class Fragment(Layout): # The thread dimension (IterVar) is accessed via the `thread` property forward_thread = self.thread # Construct an IndexMap to map the provided args into the final thread index - index_map = IndexMap( - initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None) + index_map = IndexMap(initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None) return index_map.map_indices(indices) def __repr__(self): @@ -206,7 +198,7 @@ class Fragment(Layout): return self._DebugOutput() # return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" - def is_equal(self, other: 'Fragment') -> bool: + def is_equal(self, other: "Fragment") -> bool: """ Check if the current fragment is equal to another fragment. """ diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index e5d1902..e68c116 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations import tvm @@ -114,8 +115,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]: raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") - if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8" - ] and buffer.dtype not in ["uint32", "int32"]: + if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]: raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") m, k = buffer.shape @@ -134,10 +134,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): return T.Layout(buffer.shape, ColumnMajorInterleaved) -def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, - mma_dtype: str = "float16", - arch: str | None = None, - **extra_args): +def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args): if arch is None: arch = nvcc.get_target_compute_version() diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index 87d2ee4..fbd39e8 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation import tvm_ffi from tvm.ir import Node, Range @@ -9,7 +10,6 @@ from tilelang import _ffi_api # Register the Layout class as a TVM object under the name "tl.Layout" @tvm_ffi.register_object("tl.Layout") class Layout(Node): - def __init__(self, shape, forward_fn): """ Initialize a Layout object. @@ -114,13 +114,13 @@ class Layout(Node): index_map = IndexMap( initial_indices=forward_vars, # The original iteration variables final_indices=forward_indexes, # The computed forward indices - inverse_index_map=None # No inverse mapping provided at this stage + inverse_index_map=None, # No inverse mapping provided at this stage ) # Map the provided indices using the constructed index mapping return index_map.map_indices(indices) - def inverse(self) -> 'Layout': + def inverse(self) -> "Layout": """ Compute the inverse of the current layout transformation. @@ -131,7 +131,7 @@ class Layout(Node): """ return _ffi_api.Layout_inverse(self) - def is_equal(self, other: 'Layout') -> bool: + def is_equal(self, other: "Layout") -> bool: """ Check if the current layout is equal to another layout. diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 3a219c6..e083d75 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations @@ -7,9 +8,7 @@ from tvm.tir import Buffer, BufferLoad, BufferRegion from tilelang import _ffi_api -def _get_buffer_info( - buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion -) -> tuple[Buffer, list[int], str]: +def _get_buffer_info(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[Buffer, list[int], str]: """ Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion. @@ -25,12 +24,10 @@ def _get_buffer_info( buf = buffer_or_load_or_region.buffer return buf, buf.shape, buf.dtype else: - raise TypeError( - f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") -def _get_stride_continuous( - buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]: +def _get_stride_continuous(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]: """ Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion. @@ -62,9 +59,7 @@ def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegi # Use a stable swizzled layout to ensure consistent memory access patterns. # Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. -def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - k_major: bool = True, - allow_pad: bool = True): +def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, k_major: bool = True, allow_pad: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) return _ffi_api.make_swizzled_layout( @@ -77,9 +72,7 @@ def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for Volta Intrinsics -def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - is_a: bool = True, - k_inner: bool = True): +def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, is_a: bool = True, k_inner: bool = True): stride, continuous = _get_stride_continuous(buffer) return _ffi_api.make_volta_swizzled_layout( stride, @@ -90,9 +83,7 @@ def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for WGMMA Intrinsics -def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - continuity: int = None, - k_major: bool = True): +def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) if continuity is None: @@ -107,9 +98,7 @@ def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for TCGEN05MMA Intrinsics -def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - continuity: int = None, - k_major: bool = True): +def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) if continuity is None: diff --git a/tilelang/libinfo.py b/tilelang/libinfo.py index 5af8c84..d82986b 100644 --- a/tilelang/libinfo.py +++ b/tilelang/libinfo.py @@ -31,6 +31,5 @@ def find_lib_path(name: str, py_ext=False): if os.path.exists(lib_dll_path) and os.path.isfile(lib_dll_path): return lib_dll_path else: - message = (f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + - "\n".join(TL_LIBS)) + message = f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + "\n".join(TL_LIBS) raise RuntimeError(message) diff --git a/tilelang/primitives/__init__.py b/tilelang/primitives/__init__.py index 8eccc3e..9d2a739 100644 --- a/tilelang/primitives/__init__.py +++ b/tilelang/primitives/__init__.py @@ -1,3 +1,3 @@ -""" bootstrap the primitives module via tile language """ +"""bootstrap the primitives module via tile language""" from .gemm import gemm # noqa: F401 diff --git a/tilelang/primitives/gemm/__init__.py b/tilelang/primitives/gemm/__init__.py index 2484374..7664a7b 100644 --- a/tilelang/primitives/gemm/__init__.py +++ b/tilelang/primitives/gemm/__init__.py @@ -3,7 +3,8 @@ from tvm import tir from tilelang.utils import is_local, is_fragment, is_shared from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.gemm_mma import ( - GemmPrimitiveMMA,) + GemmPrimitiveMMA, +) def gemm( @@ -20,12 +21,9 @@ def gemm( policy: GemmWarpPolicy = GemmWarpPolicy.Square, k_pack: int = 1, ): - assert is_local(A) or is_fragment(A) or is_shared(A), ( - f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}") - assert is_local(B) or is_fragment(B) or is_shared(B), ( - f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}") - assert is_local(C) or is_fragment(C), ( - f"Expected C to be a local, fragment, but got {C.scope()}") + assert is_local(A) or is_fragment(A) or is_shared(A), f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}" + assert is_local(B) or is_fragment(B) or is_shared(B), f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}" + assert is_local(C) or is_fragment(C), f"Expected C to be a local, fragment, but got {C.scope()}" # TODO(lei): Now we only support Nvidia GPUs # Must enhance the design to implement runtime lowering # for different targets (hip mfma for example) diff --git a/tilelang/primitives/gemm/base.py b/tilelang/primitives/gemm/base.py index 827ff78..b7fcdca 100644 --- a/tilelang/primitives/gemm/base.py +++ b/tilelang/primitives/gemm/base.py @@ -131,7 +131,7 @@ class GemmWarpPolicy(IntEnum): # Try to find the best balanced partition best_m = 1 best_n = 1 - best_balance = float('inf') + best_balance = float("inf") # Try all possible combinations that satisfy the constraints for m in range(1, min(max_m_warps, num_warps) + 1): @@ -202,7 +202,7 @@ class GemmBaseParams: warp_row_tiles: int | None = None warp_col_tiles: int | None = None chunk: int | None = None - policy: GemmWarpPolicy = GemmWarpPolicy.Square, + policy: GemmWarpPolicy = (GemmWarpPolicy.Square,) k_pack: int = 1 def get_warp_size(self) -> int: @@ -267,17 +267,17 @@ class GemmBaseParams: # Determine whether block partition parameters need to be inferred require_infer = ( - block_row_warps is None or block_col_warps is None or warp_row_tiles is None or - warp_col_tiles is None or chunk is None) + block_row_warps is None or block_col_warps is None or warp_row_tiles is None or warp_col_tiles is None or chunk is None + ) A_shape, B_shape = A.shape, B.shape if require_infer: - assert (threads is not None), "threads must be provided for auto inference" + assert threads is not None, "threads must be provided for auto inference" # Auto-inference only supports 2D matrix multiplication - assert ( - len(A_shape) == 2 and len(B_shape) == 2 - ), f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D" + assert len(A_shape) == 2 and len(B_shape) == 2, ( + f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D" + ) # Analyze A/B shapes AM = A_shape[1] if transpose_A else A_shape[0] # M dimension @@ -291,8 +291,7 @@ class GemmBaseParams: num_warps = threads // warp_size # Infer block partition using a user-specified policy - block_row_warps, block_col_warps = policy.compute_warp_partition( - block_M, block_N, num_warps) + block_row_warps, block_col_warps = policy.compute_warp_partition(block_M, block_N, num_warps) warp_row_tiles = block_M // block_row_warps warp_col_tiles = block_N // block_col_warps chunk = int(AK) diff --git a/tilelang/primitives/gemm/gemm_mma.py b/tilelang/primitives/gemm/gemm_mma.py index 11e1683..7ca3208 100644 --- a/tilelang/primitives/gemm/gemm_mma.py +++ b/tilelang/primitives/gemm/gemm_mma.py @@ -31,7 +31,6 @@ class GemmPrimitiveMMA(GemmBaseParams): C: tir.Buffer, mma_emitter: TensorCoreIntrinEmitter, ) -> tir.PrimExpr: - in_dtype = self.in_dtype warp_cols = mma_emitter.warp_cols local_size_b = mma_emitter.local_size_b @@ -53,21 +52,24 @@ class GemmPrimitiveMMA(GemmBaseParams): if a_is_fragment: # Annotate layout for A_local if it is a fragment. - T.annotate_layout({ - A_local: mma_emitter.make_mma_load_layout(A_local, "A"), - }) + T.annotate_layout( + { + A_local: mma_emitter.make_mma_load_layout(A_local, "A"), + } + ) if c_is_fragment: # Annotate layout for C_local if it is a fragment. - T.annotate_layout({ - C_local: mma_emitter.make_mma_store_layout(C_local), - }) + T.annotate_layout( + { + C_local: mma_emitter.make_mma_store_layout(C_local), + } + ) # Make default swizzle layout for shared memory # T.annotate_layout({ # B_shared: make_mma_swizzle_layout(B_shared), # }) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -146,9 +148,11 @@ class GemmPrimitiveMMA(GemmBaseParams): if c_is_fragment: # Annotate layout for C_local if it is a fragment. - T.annotate_layout({ - C_local: mma_emitter.make_mma_store_layout(C_local), - }) + T.annotate_layout( + { + C_local: mma_emitter.make_mma_store_layout(C_local), + } + ) for ki in T.serial(0, (block_K // micro_size_k)): # Load A into fragment diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 4750fa7..94d3501 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from typing import Callable, Any, Literal from functools import partial @@ -45,8 +46,7 @@ class Profiler: result_idx = [] elif isinstance(result_idx, int): if result_idx > len(params) or result_idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") + raise ValueError(f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") if result_idx < 0: result_idx = len(params) + result_idx result_idx = [result_idx] @@ -113,8 +113,7 @@ class Profiler: ref_tensors = ins + ref_outs lib_tensors = ins + lib_outs - assert len(lib_tensors) == len( - ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" + assert len(lib_tensors) == len(ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" # torch.set_printoptions(edgeitems=torch.inf) for lhs, rhs in zip(lib_tensors, ref_tensors): # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) @@ -252,10 +251,9 @@ class Profiler: ) elif profiler == "tvm": assert func is not None, "func should not be None" - assert isinstance( - func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + assert isinstance(func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" - ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) + ins = self._get_inputs(with_output=True) if input_tensors is None else input_tensors target = "cuda" with suppress(Exception): @@ -264,8 +262,7 @@ class Profiler: assert target in ["cuda", "hip"], f"Unknown target: {target}" device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) - time_evaluator = self.mod.time_evaluator( - self.mod.entry_name, device, number=rep, repeat=n_repeat) + time_evaluator = self.mod.time_evaluator(self.mod.entry_name, device, number=rep, repeat=n_repeat) # Transform Latency to ms return time_evaluator(*ins).mean * 1e3 else: diff --git a/tilelang/profiler/bench.py b/tilelang/profiler/bench.py index a851ceb..bfcb504 100644 --- a/tilelang/profiler/bench.py +++ b/tilelang/profiler/bench.py @@ -1,4 +1,5 @@ """Profiler and benchmarking utilities for PyTorch functions.""" + from __future__ import annotations import os @@ -16,8 +17,8 @@ class suppress_stdout_stderr: def __enter__(self): # Open null device files - self.outnull_file = open(os.devnull, 'w') - self.errnull_file = open(os.devnull, 'w') + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") # Save original file descriptors self.old_stdout_fileno_undup = sys.stdout.fileno() @@ -56,7 +57,7 @@ class suppress_stdout_stderr: IS_CUDA = torch.cuda.is_available() -device = 'cuda:0' if IS_CUDA else 'mps:0' +device = "cuda:0" if IS_CUDA else "mps:0" Event = torch.cuda.Event if IS_CUDA else torch.mps.Event @@ -93,8 +94,7 @@ def do_bench( Returns: Runtime in milliseconds (float) or list of quantile values if quantiles specified """ - assert return_mode in ["min", "max", "mean", "median"], \ - f"Invalid return_mode: {return_mode}" + assert return_mode in ["min", "max", "mean", "median"], f"Invalid return_mode: {return_mode}" # Initial function call and synchronization fn() diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py index e4e7f7e..e0788da 100644 --- a/tilelang/quantize/lop3.py +++ b/tilelang/quantize/lop3.py @@ -1130,16 +1130,13 @@ def get_lop3_intrin_group( Dict[str, str] A dictionary mapping the names of the intrinsics to their corresponding implementations. """ - assert out_dtype in [ - "float16", "int8", "int4" - ], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .") + assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} target_dtype = dtype_mapping[out_dtype] if source_format not in ["int", "uint"]: - raise ValueError( - f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.") + raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.") if with_zeros and source_format == "int": raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py index 80f3e06..e5c472c 100644 --- a/tilelang/quantize/mxfp.py +++ b/tilelang/quantize/mxfp.py @@ -80,13 +80,9 @@ def get_mxfp_intrin_group( AssertionError: if out_dtype, source_format, or storage_dtype are not supported. KeyError: if the constructed key does not match any available C source implementation. """ - assert out_dtype in ["float16", "bfloat16" - ], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." - assert source_format in ["int", "uint" - ], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." - assert storage_dtype in [ - "int32", "int8", "uint8" - ], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." + assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." + assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." + assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." dtype_map = {"float16": "f16", "bfloat16": "bf16"} key = f"fp{source_bit}_to_{dtype_map[out_dtype]}" diff --git a/tilelang/quantize/utils.py b/tilelang/quantize/utils.py index 2447ca1..2d092a0 100644 --- a/tilelang/quantize/utils.py +++ b/tilelang/quantize/utils.py @@ -1,6 +1,7 @@ def gen_quant4(k, n, groupsize=-1): import torch import torch.nn as nn + maxq = 2**4 w = torch.randn((k, n), dtype=torch.half, device="cpu") @@ -48,6 +49,7 @@ def gen_quant4(k, n, groupsize=-1): def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): import torch + if storage_dtype is None: storage_dtype = torch.int8 elems_per_byte = 8 // source_bits @@ -56,11 +58,11 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): int8_weight = torch.zeros( (*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte), dtype=torch.int8, - device=lowprecision_weight.device) + device=lowprecision_weight.device, + ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << - (source_bits * k)).to(torch.int8) + int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << (source_bits * k)).to(torch.int8) return int8_weight.to(storage_dtype) @@ -82,6 +84,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): interleave_weight(qweight, 4, "float16") """ import torch + assert target_dtype in ["float16", "int8"] # reinterpret the data type of qweight to int32 qweight = qweight.view(torch.int32) diff --git a/tilelang/testing/__init__.py b/tilelang/testing/__init__.py index 6a20314..635fad3 100644 --- a/tilelang/testing/__init__.py +++ b/tilelang/testing/__init__.py @@ -5,20 +5,19 @@ import random import torch import numpy as np from tilelang.contrib import nvcc -from tvm.testing.utils import (requires_cuda, requires_package, requires_llvm, requires_metal, - requires_rocm, _compose) +from tvm.testing.utils import requires_cuda, requires_package, requires_llvm, requires_metal, requires_rocm, _compose from tilelang.utils.tensor import torch_assert_close as torch_assert_close __all__ = [ - 'requires_package', - 'requires_cuda', - 'requires_metal', - 'requires_rocm', - 'requires_llvm', - 'main', - 'requires_cuda_compute_version', -] + [f'requires_cuda_compute_version_{op}' for op in ('ge', 'gt', 'le', 'lt', 'eq')] + "requires_package", + "requires_cuda", + "requires_metal", + "requires_rocm", + "requires_llvm", + "main", + "requires_cuda_compute_version", +] + [f"requires_cuda_compute_version_{op}" for op in ("ge", "gt", "le", "lt", "eq")] # pytest.main() wrapper to allow running single test file diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 9096090..4d2caf8 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -23,8 +23,7 @@ def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range) @tvm_ffi.register_global_func("tl.gemm_py.lower") -def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, - thread_var: tir.Var): +def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) return stmt diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/tileop/gemm/gemm_mfma.py index 862ec72..d827d8a 100644 --- a/tilelang/tileop/gemm/gemm_mfma.py +++ b/tilelang/tileop/gemm/gemm_mfma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mfma_macro_generator import ( - MatrixCoreIntrinEmitter,) + MatrixCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmMFMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mfma_emitter = MatrixCoreIntrinEmitter( @@ -56,12 +55,10 @@ class GemmMFMA(GemmBase): self.C: mfma_emitter.make_mfma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mfma_emitter = MatrixCoreIntrinEmitter( @@ -153,7 +150,6 @@ class GemmMFMA(GemmBase): T.clear(C_buf) for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): - # Load A into fragment mfma_emitter.ldmatrix_a( A_local, @@ -183,7 +179,6 @@ class GemmMFMA(GemmBase): if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): - # Load B into fragment mfma_emitter.ldmatrix_b( B_local, @@ -217,8 +212,7 @@ class GemmMFMA(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index ce27409..b151734 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmMMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -54,12 +53,10 @@ class GemmMMA(GemmBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -177,7 +174,6 @@ class GemmMMA(GemmBase): if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -211,8 +207,7 @@ class GemmMMA(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rrr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_mma_sm70.py b/tilelang/tileop/gemm/gemm_mma_sm70.py index 12b729c..52a4bf3 100644 --- a/tilelang/tileop/gemm/gemm_mma_sm70.py +++ b/tilelang/tileop/gemm/gemm_mma_sm70.py @@ -2,7 +2,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_volta_swizzled_layout from tilelang.intrinsics.mma_sm70_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -12,10 +13,8 @@ from tilelang.transform.simplify import _Simplify class GemmMMASm70(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -45,12 +44,10 @@ class GemmMMASm70(GemmBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -140,7 +137,6 @@ class GemmMMASm70(GemmBase): T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -155,8 +151,7 @@ class GemmMMASm70(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 76f919e..f93a403 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_tcgen05mma_swizzled_layout from tilelang.intrinsics.tcgen05_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang import language as T from tilelang.transform.simplify import _Simplify from tvm import tir @@ -18,10 +19,8 @@ _FLOAT8_DTYPES = { class GemmTCGEN5(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -40,27 +39,20 @@ class GemmTCGEN5(GemmBase): b_is_k_major = self.trans_B if self.is_gemm_ss(): - a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp b_continuity = self.K if b_is_k_major else self.N // n_warp return { # WGMMA does not support padding - self.A: - make_tcgen05mma_swizzled_layout( - self.A, continuity=a_continuity, k_major=a_is_k_major), - self.B: - make_tcgen05mma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: make_tcgen05mma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_tcgen05mma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } # No special swizzle requirement; rely on existing layout. return {} def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -82,11 +74,9 @@ class GemmTCGEN5(GemmBase): mma_emitter._assign_b_shared_layout(layout_map[self.B]) if not self.is_gemm_ss(): - raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " - f"A scope {self.A.scope()}, B scope {self.B.scope()}") + raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got A scope {self.A.scope()}, B scope {self.B.scope()}") - atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( - self.M, self.N, self.K) + atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K) if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") @@ -108,7 +98,7 @@ class GemmTCGEN5(GemmBase): raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype not in ["float32", 'float16']: + if accum_dtype not in ["float32", "float16"]: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/tileop/gemm/gemm_wgmma.py index 2325f45..038aa2c 100644 --- a/tilelang/tileop/gemm/gemm_wgmma.py +++ b/tilelang/tileop/gemm/gemm_wgmma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_wgmma_swizzled_layout from tilelang.intrinsics.wgmma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmWGMMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -38,33 +37,22 @@ class GemmWGMMA(GemmBase): return { # WGMMA does not support padding - self.A: - make_wgmma_swizzled_layout( - self.A, continuity=a_continuity, k_major=a_is_k_major), - self.B: - make_wgmma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: make_wgmma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_rs(): b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp return { - self.A: - mma_emitter.make_mma_load_layout(self.A, matrix="A"), - self.B: - make_wgmma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) @@ -133,8 +121,7 @@ class GemmWGMMA(GemmBase): # Simplify to optimize the index computing # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) - raise ValueError( - f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py index fdac694..c22bca8 100644 --- a/tilelang/tileop/gemm_sp/__init__.py +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -1,7 +1,8 @@ from tilelang import tvm as tvm from tvm import tir from tilelang.utils.target import ( - target_is_cuda,) + target_is_cuda, +) from tvm.target import Target from tvm.ir.base import Node from tvm.ir import Range @@ -18,8 +19,7 @@ def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds @tvm_ffi.register_global_func("tl.gemm_sp_py.lower") -def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, - thread_var: tir.Var): +def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_sp_py.lower(target, thread_nums, thread_var) return stmt diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/tileop/gemm_sp/gemm_sp_mma.py index 50a40bb..76a0d4a 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_mma.py +++ b/tilelang/tileop/gemm_sp/gemm_sp_mma.py @@ -10,10 +10,8 @@ from tilelang.transform.simplify import _Simplify class GemmSPMMA(GemmSPBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( @@ -55,12 +53,10 @@ class GemmSPMMA(GemmSPBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( @@ -146,7 +142,6 @@ class GemmSPMMA(GemmSPBase): E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) for ki in T.serial(0, (self.K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -231,8 +226,7 @@ class GemmSPMMA(GemmSPBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rrr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tools/Analyzer.py b/tilelang/tools/Analyzer.py index 205c647..3af5222 100644 --- a/tilelang/tools/Analyzer.py +++ b/tilelang/tools/Analyzer.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from tilelang import tvm from tvm.tir.stmt_functor import ir_transform import logging + # Configuration for different hardware architectures. # Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count) ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)} @@ -23,6 +24,7 @@ class AnalysisResult: tflops: Achieved TFLOPS (trillions of FLOPs per second). bandwidth_GBps: Achieved memory bandwidth in GB/s. """ + total_flops: int total_global_bytes: int estimated_time: float @@ -81,7 +83,7 @@ class Analyzer: # Account for loop and block dimensions loop_product = 1 for extent in self.loop_stack: - loop_product *= extent.value if hasattr(extent, 'value') else extent + loop_product *= extent.value if hasattr(extent, "value") else extent total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] total_bytes = bytes_transferred * loop_product * total_blocks self.total_global_bytes += total_bytes @@ -100,7 +102,7 @@ class Analyzer: # Account for loop and block dimensions loop_product = 1 for extent in self.loop_stack: - loop_product *= extent.value if hasattr(extent, 'value') else extent + loop_product *= extent.value if hasattr(extent, "value") else extent total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] self.total_flops += flops_per_call * loop_product * total_blocks @@ -127,8 +129,7 @@ class Analyzer: iter_var = stmt.node thread_tag = iter_var.thread_tag if thread_tag in self.block_counts: - extent = stmt.value.value if hasattr(stmt.value, - 'value') else stmt.value + extent = stmt.value.value if hasattr(stmt.value, "value") else stmt.value self.block_counts[thread_tag] = extent elif isinstance(stmt, tvm.tir.For): # Push loop extent onto the stack @@ -178,9 +179,7 @@ class Analyzer: """ arch_key = device.compute_capability[:2] if arch_key not in ARCH_CONFIGS: - logger.info( - f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None" - ) + logger.info(f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None") return None cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key] @@ -203,7 +202,8 @@ class Analyzer: total_global_bytes=self.total_global_bytes, estimated_time=estimated_time, expected_tflops=peak_tflops, - expected_bandwidth_GBps=bandwidth_GBps) + expected_bandwidth_GBps=bandwidth_GBps, + ) @classmethod def analysis(cls, fn, device): diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 06e01f4..299c3e8 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -2,12 +2,14 @@ from __future__ import annotations import tilelang.language as T -def plot_layout(layout: T.Fragment, - save_directory="./tmp", - name: str = "layout", - colormap: str = "RdPu", - verbose: bool = False, - formats: str | list[str] = "png") -> None: +def plot_layout( + layout: T.Fragment, + save_directory="./tmp", + name: str = "layout", + colormap: str = "RdPu", + verbose: bool = False, + formats: str | list[str] = "png", +) -> None: """ Plot the layout of a buffer. @@ -90,11 +92,13 @@ def plot_layout(layout: T.Fragment, # Warn if the number of threads is less than the warp size if num_threads < warp_size: import warnings + warnings.warn( f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). " f"For the best viewing experience, it is recommended to have at least {warp_size} threads.", UserWarning, - stacklevel=2) + stacklevel=2, + ) spectral_camp = plt.get_cmap("hsv", warp_size * 6) for i in range(min(warp_size, num_threads)): @@ -118,12 +122,7 @@ def plot_layout(layout: T.Fragment, color = colors[thread_ids[0]] # Select color based on thread ID # Create a rectangle patch for visualization - rect = patches.Rectangle((j, i), - 1, - 1, - linewidth=0.5, - edgecolor='black', - facecolor=color) + rect = patches.Rectangle((j, i), 1, 1, linewidth=0.5, edgecolor="black", facecolor=color) ax.add_patch(rect) # Add the rectangle to the plot # Add text annotations inside the rectangles @@ -139,41 +138,19 @@ def plot_layout(layout: T.Fragment, thread_fontsize = min(font_size, font_size * (4 / len(thread_str))) # Add thread ID text with adjusted font size - ax.text( - j + 0.5, - i + 0.3, - thread_str, - ha='center', - va='center', - color='black', - fontsize=thread_fontsize) + ax.text(j + 0.5, i + 0.3, thread_str, ha="center", va="center", color="black", fontsize=thread_fontsize) # Add local ID text with original font size - ax.text( - j + 0.5, - i + 0.7, - f"L{local_id}", - ha='center', - va='center', - color='black', - fontsize=font_size) + ax.text(j + 0.5, i + 0.7, f"L{local_id}", ha="center", va="center", color="black", fontsize=font_size) # Add row labels to the left side of the plot for i in range(nrows): text = f"row {i}" - ax.text(-0.75, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size) + ax.text(-0.75, i + 0.5, text, ha="center", va="center", color="black", fontsize=font_size) # Add column labels at the top of the plot for j in range(ncols): text = f"col {j}" - ax.text( - j + 0.5, - -0.5, - text, - ha='center', - va='center', - color='black', - fontsize=font_size, - rotation=45) + ax.text(j + 0.5, -0.5, text, ha="center", va="center", color="black", fontsize=font_size, rotation=45) # Set the plot limits ax.set_xlim(0, ncols) @@ -189,17 +166,15 @@ def plot_layout(layout: T.Fragment, legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height - legend_patches = [ - patches.Patch(color='black', label="T: Thread ID"), - patches.Patch(color='black', label="L: Local ID") - ] + legend_patches = [patches.Patch(color="black", label="T: Thread ID"), patches.Patch(color="black", label="L: Local ID")] ax.legend( handles=legend_patches, loc="upper right", fontsize=font_size - 4, frameon=False, bbox_to_anchor=(legend_x, legend_y), # Dynamic position - ncols=2) + ncols=2, + ) # Create the output directory if it does not exist tmp_directory = pathlib.Path(save_directory) @@ -211,28 +186,29 @@ def plot_layout(layout: T.Fragment, if isinstance(formats, str): formats_str = formats.strip().lower() - if formats_str == 'all': - formats_list = ['pdf', 'png', 'svg'] + if formats_str == "all": + formats_list = ["pdf", "png", "svg"] elif "," in formats_str: - formats_list = [f.strip() for f in formats_str.split(',')] + formats_list = [f.strip() for f in formats_str.split(",")] else: formats_list = [formats_str] else: - raise TypeError(f"Expected str, but got {type(formats).__name__}. " - f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") + raise TypeError( + f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." + ) # Save the figure - if 'pdf' in formats_list: + if "pdf" in formats_list: pdf_path = tmp_directory / f"{name}.pdf" plt.savefig(pdf_path, bbox_inches="tight") print(f"Saved pdf format into {pdf_path}") - if 'png' in formats_list: + if "png" in formats_list: png_path = tmp_directory / f"{name}.png" plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) print(f"Saved png format into {png_path}") - if 'svg' in formats_list: + if "svg" in formats_list: svg_path = tmp_directory / f"{name}.svg" plt.savefig(svg_path, bbox_inches="tight", format="svg") print(f"Saved svg format into {svg_path}") diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index a86ffe2..bb9202a 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -110,8 +110,7 @@ def LowerHopperIntrin(): fpass : tvm.transform.Pass The result pass """ - return (_ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f - ) # type: ignore + return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore def WarpSpecializedPipeline(): @@ -365,8 +364,7 @@ def FlattenBuffer(): def EliminateStorageSyncForMBarrier(): - """EliminateStorageSyncForMBarrier - """ + """EliminateStorageSyncForMBarrier""" return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore @@ -378,19 +376,16 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, - align_bytes) # type: ignore + return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore def LowerL2Persistent(): - """LowerL2Persistent - """ + """LowerL2Persistent""" return _ffi_api.LowerL2Persistent() # type: ignore def PersistThreadblock(): - """PersistThreadblock - """ + """PersistThreadblock""" return _ffi_api.PersistThreadblock() # type: ignore @@ -409,8 +404,7 @@ def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16): def LowerSharedBarrier(): - """LowerSharedBarrier - """ + """LowerSharedBarrier""" return _ffi_api.LowerSharedBarrier() # type: ignore @@ -437,20 +431,17 @@ def StorageRewrite(): def LowerOpaqueBlock(): - """LowerOpaqueBlock - """ + """LowerOpaqueBlock""" return _ffi_api.LowerOpaqueBlock() # type: ignore def LowerThreadAllreduce(): - """LowerThreadAllreduce - """ + """LowerThreadAllreduce""" return _ffi_api.LowerThreadAllreduce() # type: ignore def LowerIntrin(): - """LowerIntrin - """ + """LowerIntrin""" return _ffi_api.LowerIntrin() # type: ignore @@ -468,8 +459,7 @@ def LowerDeviceKernelLaunch(): def LowerSharedTmem(): - """LowerSharedTmem - """ + """LowerSharedTmem""" return _ffi_api.LowerSharedTmem() # type: ignore diff --git a/tilelang/transform/add_bufstore_wrapper.py b/tilelang/transform/add_bufstore_wrapper.py index d8457f9..c1dd41e 100644 --- a/tilelang/transform/add_bufstore_wrapper.py +++ b/tilelang/transform/add_bufstore_wrapper.py @@ -1,4 +1,4 @@ -from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm) +from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm from tvm.tir.stmt_functor import ir_transform, post_order_visit from tvm.tir.transform import prim_func_pass @@ -97,7 +97,7 @@ def AddWrapperForSingleBufStore(): Returns: True if the loop is a tile operation (parallel or has num_stages annotation) """ - return loop.kind == ForKind.PARALLEL or 'num_stages' in loop.annotations + return loop.kind == ForKind.PARALLEL or "num_stages" in loop.annotations def pre_visit(statement): """ @@ -105,7 +105,7 @@ def AddWrapperForSingleBufStore(): """ nonlocal tile_operation_depth - if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent': + if isinstance(statement, AttrStmt) and statement.attr_key == "thread_extent": thread_binding_vars.add(statement.node.var) elif isinstance(statement, For) and is_tile_operation_loop(statement): tile_operation_depth += 1 @@ -139,7 +139,8 @@ def AddWrapperForSingleBufStore(): if isinstance(index, IntImm) and index != 0: raise ValueError( f"Fragment buffer access with non-zero index [{index}] is not supported. " - "Only fragment[0] access is allowed.") + "Only fragment[0] access is allowed." + ) # Wrap fragment[0] access with T.Parallel loop return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement) diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 92adcb4..92a7313 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -5,6 +5,7 @@ from enum import Enum class PassConfigKey(str, Enum): """Pass configuration keys for TileLang compiler.""" + # TileLang specific configs TL_SIMPLIFY = "tl.Simplify" """Enable/disable TileLang simplification passes. Default: True""" diff --git a/tilelang/transform/simplify.py b/tilelang/transform/simplify.py index 7e0c506..c5e577d 100644 --- a/tilelang/transform/simplify.py +++ b/tilelang/transform/simplify.py @@ -51,7 +51,6 @@ def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | # Decorator to simplify the output of a function def simplify_prim_func(func: Callable) -> Callable: - def wrapper(*args, **kwargs): stmt: PrimFunc | IRModule = (func)(*args, **kwargs) return _Simplify(stmt) diff --git a/tilelang/utils/deprecated.py b/tilelang/utils/deprecated.py index 2aff08b..2944f29 100644 --- a/tilelang/utils/deprecated.py +++ b/tilelang/utils/deprecated.py @@ -1,11 +1,10 @@ def deprecated_warning(method_name: str, new_method_name: str, phaseout_version: str = None): - """A function to indicate that a method is deprecated - """ + """A function to indicate that a method is deprecated""" import warnings # pylint: disable=import-outside-toplevel, import-error warnings.warn( - f"{method_name} is deprecated, use {new_method_name} instead" + - (f" and will be removed in {phaseout_version}" if phaseout_version else ""), + f"{method_name} is deprecated, use {new_method_name} instead" + + (f" and will be removed in {phaseout_version}" if phaseout_version else ""), DeprecationWarning, stacklevel=2, ) @@ -30,7 +29,6 @@ def deprecated( import functools # pylint: disable=import-outside-toplevel def _deprecate(func): - @functools.wraps(func) def _wrapper(*args, **kwargs): deprecated_warning(method_name, new_method_name, phaseout_version) diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 41da8ab..584e999 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -24,8 +24,7 @@ def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)): return buffer_or_load_or_region.buffer else: - raise TypeError( - f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool: @@ -153,14 +152,12 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: """ if not isinstance(ir_module, IRModule): raise ValueError("Not supported type: ", type(ir_module)) - assert len(ir_module.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." func = list(ir_module.functions.values())[0] return func -def get_buffer_region_from_load(buffer_load: tir.BufferLoad, - extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: +def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: """ Get the buffer region from a buffer load. @@ -193,9 +190,9 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad, return None -def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var, - access_type: str = "rw", - extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion: +def to_buffer_region( + obj: Buffer | BufferLoad | BufferRegion | tir.Var, access_type: str = "rw", extents: list[PrimExpr] | None = None +) -> PrimExpr | BufferRegion: """ Convert to/from the tl.region representation. @@ -203,6 +200,7 @@ def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var, - tl.region Call -> returns the decoded BufferRegion for analysis """ from tilelang.language.frame import has_let_value, get_let_value + if isinstance(obj, tir.Var) and has_let_value(obj): obj = get_let_value(obj) # Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis @@ -279,8 +277,7 @@ def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list: return strides -def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, - access_type: str = "r") -> PrimExpr: +def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, access_type: str = "r") -> PrimExpr: if isinstance(buffer_or_load_or_region, Buffer): return buffer_or_load_or_region.access_ptr(access_type) elif isinstance(buffer_or_load_or_region, BufferLoad): diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index a7b17ad..26a8e34 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -15,7 +15,7 @@ os.makedirs(_CACHE_DIR, exist_ok=True) def _get_cached_lib(): - name = 'compress_lib' + name = "compress_lib" if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")): try: @@ -32,24 +32,22 @@ def _get_cached_lib(): name=name, sources=[compress_util], extra_cuda_cflags=[ - '-O2', - '-std=c++17', - '-lineinfo', - f'-I{env.CUTLASS_INCLUDE_DIR}', - f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include', - '-arch=sm_90', + "-O2", + "-std=c++17", + "-lineinfo", + f"-I{env.CUTLASS_INCLUDE_DIR}", + f"-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include", + "-arch=sm_90", ], build_directory=_CACHE_DIR, ) -def compress_sm90(A: torch.Tensor, block_k: int, - transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: +def compress_sm90(A: torch.Tensor, block_k: int, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: if block_k > 128: block_k = 128 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 - warnings.warn( - f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2) + warnings.warn(f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2) # Load the library (will use cache if available) compress_lib = _get_cached_lib() @@ -60,8 +58,9 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc try: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor except ImportError as err: - raise ImportError("SparseSemiStructuredTensor is not available in this version of PyTorch. " - "Please install a compatible version.") from err + raise ImportError( + "SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version." + ) from err orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS try: SparseSemiStructuredTensor._FORCE_CUTLASS = True @@ -73,10 +72,7 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val -def compress(A: torch.Tensor, - transposed: bool, - arch: str | None = None, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: +def compress(A: torch.Tensor, transposed: bool, arch: str | None = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: """ Compress a tensor using the appropriate method based on the CUDA architecture. """ @@ -101,11 +97,10 @@ def compress(A: torch.Tensor, A_sp = A_sp.t().contiguous() return A_sp, E else: - raise ValueError(f"Unsupported CUDA compute version: {compute_version}. " - "Supported versions are sm_80 and sm_90.") + raise ValueError(f"Unsupported CUDA compute version: {compute_version}. Supported versions are sm_80 and sm_90.") -def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False): +def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): """ Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension. Args: @@ -127,13 +122,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp return tensor.to(dtype) # dtype like float8 might not have randn kernel -def randint_semi_sparse(M: int, - K: int, - low: int, - high: int, - dtype=torch.int32, - device='cuda', - transposed: bool = False): +def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device="cuda", transposed: bool = False): """ Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension. Args: @@ -157,11 +146,7 @@ def randint_semi_sparse(M: int, return tensor -def arange_semi_sparse(M: int, - K: int, - dtype=torch.float16, - device='cuda', - transposed: bool = False): +def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): """ Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension. Args: diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 094c099..4ead7ef 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -56,11 +56,10 @@ def check_metal_availability() -> bool: if not mac_release: return False # todo: check torch version? - return arch == 'arm64' + return arch == "arm64" -def determine_target(target: str | Target | Literal["auto"] = "auto", - return_object: bool = False) -> str | Target: +def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target: """ Determine the appropriate target for compilation (CUDA, HIP, or manual selection). diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index b2905fb..f1d4fc7 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from enum import Enum import torch from tvm import tir @@ -17,7 +18,7 @@ def is_float8_dtype(dtype: torch.dtype) -> bool: def fp8_remove_negative_zeros_(tensor: torch.Tensor): assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype" bits = tensor.view(torch.uint8) - zeros_mask = (tensor == 0) + zeros_mask = tensor == 0 bits[zeros_mask] = 0x00 @@ -33,26 +34,21 @@ class TensorSupplyType(Enum): def map_torch_type(intype: str) -> torch.dtype: if intype == "float8_e4m3": - assert hasattr(torch, "float8_e4m3fn"), \ - "torch.float8_e4m3fn is not supported in this version of torch" \ - "Please upgrade torch >= 2.1.0" + assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0" return torch.float8_e4m3fn elif intype == "float8_e5m2": - assert hasattr(torch, "float8_e5m2"), \ - "torch.float8_e5m2 is not supported in this version of torch" \ - "Please upgrade torch >= 2.1.0" + assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0" return torch.float8_e5m2 elif intype == "e4m3fnuz_float8": - assert hasattr(torch, "float8_e4m3fnuz"), \ - "torch.float8_e4m3fnuz is not supported in this version of torch" \ - "Please upgrade torch >= 2.2.0" + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torchPlease upgrade torch >= 2.2.0" + ) return torch.float8_e4m3fnuz else: return getattr(torch, intype) def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): - from tilelang.engine.param import KernelParam from .device import get_current_device @@ -63,7 +59,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if hasattr(param, "shape") and not param.shape: raise ValueError( f"TensorType must have a shape, but got {type(param)}, " - "likely you are trying to generate a random tensor with a dynamic symbolic shape.") + "likely you are trying to generate a random tensor with a dynamic symbolic shape." + ) # Check if with dynamic symbolic shape for shape in param.shape: @@ -81,8 +78,7 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if is_unsigned: return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) elif is_float8: - return torch.randint( - low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) elif is_boolean: return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) elif dtype in {torch.float16, torch.float32, torch.bfloat16}: @@ -91,8 +87,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) if dtype == torch.int8 and supply_type in [ - TensorSupplyType.Uniform, - TensorSupplyType.Normal, + TensorSupplyType.Uniform, + TensorSupplyType.Normal, ]: return torch.ones(*shape, device=device, dtype=dtype) @@ -103,18 +99,15 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if is_unsigned: return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) elif is_float8: - return torch.randint( - low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) elif is_boolean: return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) else: return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) elif supply_type == TensorSupplyType.Uniform: - return torch.empty( - *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) + return torch.empty(*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Normal: - return torch.empty( - *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) + return torch.empty(*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Randn: return torch.randn(*shape, device=device).to(dtype) elif supply_type == TensorSupplyType.Zero: @@ -150,9 +143,7 @@ def _compare_attributes( """ def raise_mismatch_error(attribute_name: str, actual_value, expected_value): - raise AssertionError( - f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}." - ) + raise AssertionError(f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.") if actual.shape != expected.shape: raise_mismatch_error("shape", actual.shape, expected.shape) @@ -163,7 +154,7 @@ def _compare_attributes( if actual.layout != expected.layout: if check_layout: raise_mismatch_error("layout", actual.layout, expected.layout) - elif (actual.layout == torch.strided and check_stride and actual.stride() != expected.stride()): + elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride(): raise_mismatch_error("stride()", actual.stride(), expected.stride()) if check_device and actual.device != expected.device: raise_mismatch_error("device", actual.device, expected.device) @@ -171,8 +162,7 @@ def _compare_attributes( raise_mismatch_error("dtype", actual.dtype, expected.dtype) -def _equalize_attributes(actual: torch.Tensor, - expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Equalizes some attributes of two tensors for value comparison. If ``actual`` and ``expected`` are ... - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. @@ -210,7 +200,7 @@ def _equalize_attributes(actual: torch.Tensor, if actual.layout != expected.layout: # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided actual = actual.to_dense() if actual.layout != torch.strided else actual - expected = (expected.to_dense() if expected.layout != torch.strided else expected) + expected = expected.to_dense() if expected.layout != torch.strided else expected return actual, expected @@ -254,12 +244,8 @@ def torch_assert_close( """ _compare_attributes( - tensor_a, - tensor_b, - check_device=check_device, - check_dtype=check_dtype, - check_layout=check_layout, - check_stride=check_stride) + tensor_a, tensor_b, check_device=check_device, check_dtype=check_dtype, check_layout=check_layout, check_stride=check_stride + ) tensor_a, tensor_b = _equalize_attributes(tensor_a, tensor_b) mismatched = ~torch.isclose(tensor_a, tensor_b, rtol=rtol, atol=atol, equal_nan=equal_nan) @@ -276,8 +262,7 @@ def torch_assert_close( # Print debug information about the mismatch if verbose: - print(f"Number of mismatched elements: {num_mismatched} / {total_elements} " - f"(allowed: {max_allowed_mismatched})") + print(f"Number of mismatched elements: {num_mismatched} / {total_elements} (allowed: {max_allowed_mismatched})") # If there are mismatched elements, print the first mismatch if num_mismatched > 0: @@ -289,9 +274,9 @@ def torch_assert_close( b_val = tensor_b.reshape(-1)[flat_idx].item() abs_diff = abs(a_val - b_val) rel_diff = abs_diff / (abs(b_val) + 1e-12) - mismatch_info = (f"\nFirst mismatch at index {idx}: " - f"lhs={a_val:.6f}, rhs={b_val:.6f}, " - f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}") + mismatch_info = ( + f"\nFirst mismatch at index {idx}: lhs={a_val:.6f}, rhs={b_val:.6f}, abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}" + ) else: mismatch_info = "" @@ -304,6 +289,7 @@ def torch_assert_close( f"\nGreatest absolute difference: {diff.max().item()}, " f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}" f"\n{base_name}: {tensor_a}" - f"\n{ref_name}: {tensor_b}") + f"\n{ref_name}: {tensor_b}" + ) else: return True diff --git a/version_provider.py b/version_provider.py index 3eb45aa..c2ca929 100644 --- a/version_provider.py +++ b/version_provider.py @@ -8,29 +8,26 @@ from functools import lru_cache ROOT = Path(__file__).parent -base_version = (ROOT / 'VERSION').read_text().strip() +base_version = (ROOT / "VERSION").read_text().strip() # When installing a sdist, # the installed version needs to match the sdist version, # so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`. # To workaround that, when building sdist, # we do not add version label and use a file to store the git hash instead. -git_pin = ROOT / '.git_commit.txt' +git_pin = ROOT / ".git_commit.txt" def _read_cmake_bool(i: str | None, default=False): if i is None: return default - return i.lower() not in ('0', 'false', 'off', 'no', 'n', '') + return i.lower() not in ("0", "false", "off", "no", "n", "") @lru_cache(maxsize=1) def get_git_commit_id() -> str | None: """Get the current git commit hash by running git in the current file's directory.""" - r = subprocess.run(['git', 'rev-parse', 'HEAD'], - cwd=ROOT, - capture_output=True, - encoding='utf-8') + r = subprocess.run(["git", "rev-parse", "HEAD"], cwd=ROOT, capture_output=True, encoding="utf-8") if r.returncode == 0: _git = r.stdout.strip() git_pin.write_text(_git) @@ -41,51 +38,48 @@ def get_git_commit_id() -> str | None: return None -def dynamic_metadata( - field: str, - settings: dict[str, object] | None = None, -) -> str: - assert field == 'version' +def dynamic_metadata(field: str, settings: dict[str, object] | None = None) -> str: + assert field == "version" version = base_version # generate git version for sdist get_git_commit_id() - if not _read_cmake_bool(os.environ.get('NO_VERSION_LABEL')): + if not _read_cmake_bool(os.environ.get("NO_VERSION_LABEL")): exts = [] backend = None - if _read_cmake_bool(os.environ.get('NO_TOOLCHAIN_VERSION')): + if _read_cmake_bool(os.environ.get("NO_TOOLCHAIN_VERSION")): pass - elif platform.system() == 'Darwin': + elif platform.system() == "Darwin": # only on macosx_11_0_arm64, not necessary # backend = 'metal' pass - elif _read_cmake_bool(os.environ.get('USE_ROCM', '')): - backend = 'rocm' - elif 'USE_CUDA' in os.environ and not _read_cmake_bool(os.environ.get('USE_CUDA')): - backend = 'cpu' + elif _read_cmake_bool(os.environ.get("USE_ROCM", "")): + backend = "rocm" + elif "USE_CUDA" in os.environ and not _read_cmake_bool(os.environ.get("USE_CUDA")): + backend = "cpu" else: # cuda # Read nvcc version from env. # This is not exactly how it should be, # but works for now if building in a nvidia/cuda image. - if cuda_version := os.environ.get('CUDA_VERSION'): - major, minor, *_ = cuda_version.split('.') - backend = f'cu{major}{minor}' + if cuda_version := os.environ.get("CUDA_VERSION"): + major, minor, *_ = cuda_version.split(".") + backend = f"cu{major}{minor}" else: - backend = 'cuda' + backend = "cuda" if backend: exts.append(backend) - if _read_cmake_bool(os.environ.get('NO_GIT_VERSION')): + if _read_cmake_bool(os.environ.get("NO_GIT_VERSION")): pass elif git_hash := get_git_commit_id(): - exts.append(f'git{git_hash[:8]}') + exts.append(f"git{git_hash[:8]}") else: - exts.append('gitunknown') + exts.append("gitunknown") if exts: - version += '+' + '.'.join(exts) + version += "+" + ".".join(exts) return version -- GitLab