Commit bb9beb36 authored by Tri Dao

Remove some unused headers

parent 08c295c0
@@ -2,7 +2,9 @@
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/
-#include <torch/extension.h>
+// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#include <torch/python.h>
+#include <torch/nn/functional.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
...
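The hunk above swaps the monolithic torch/extension.h for two narrower headers, since this translation unit only needs the Python bindings and torch::nn::functional. A minimal sketch of an extension file that still builds against just the slimmer pair (the pad helper and its name are illustrative, not from this commit):

```cpp
// Builds with only the two headers the commit keeps; no torch/extension.h.
#include <torch/python.h>          // pybind11 bindings + tensor types
#include <torch/nn/functional.h>   // torch::nn::functional (e.g. pad)

// Hypothetical helper: pad the last dim up to a multiple (illustrative only).
at::Tensor pad_last_dim(const at::Tensor &x, int64_t to_multiple) {
    namespace F = torch::nn::functional;
    const int64_t rem = x.size(-1) % to_multiple;
    if (rem == 0) { return x; }
    return F::pad(x, F::PadFuncOptions({0, to_multiple - rem}));
}

// TORCH_EXTENSION_NAME is injected by torch's build helpers (cpp_extension).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pad_last_dim", &pad_last_dim, "Pad last dim to a multiple (demo)");
}
```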
@@ -13,8 +13,7 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif
-#include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
...
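The comment added above records why <ATen/cuda/CUDAGraphsUtils.cuh> survives the cleanup: it provides at::cuda::philox::unpack, which resolves the dropout RNG seed/offset on the device side so the state can live in device memory under CUDA graph capture. A hedged sketch of the call, with an illustrative kernel that is not part of this commit:

```cpp
#include <ATen/cuda/CUDAGraphsUtils.cuh>  // for at::cuda::philox::unpack
#include <curand_kernel.h>

// Illustrative kernel (assumption: mirrors the common dropout RNG pattern,
// not this repo's code).
__global__ void dropout_rng_demo(at::PhiloxCudaState philox_args) {
    // unpack handles both eager mode (seed/offset passed by value) and CUDA
    // graph capture (seed/offset read from device memory).
    const auto seeds = at::cuda::philox::unpack(philox_args);
    curandStatePhilox4_32_10_t state;
    curand_init(std::get<0>(seeds),                     // seed
                blockIdx.x * blockDim.x + threadIdx.x,  // subsequence per thread
                std::get<1>(seeds),                     // offset into the stream
                &state);
    const float4 r = curand_uniform4(&state);  // four uniforms per call for a dropout mask
    (void)r;
}
```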
@@ -5,18 +5,15 @@
 #pragma once
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
 namespace flash {
...
@@ -4,20 +4,16 @@
 #pragma once
-#include <cmath>
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
 namespace flash {
@@ -25,49 +21,6 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>>>{},
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>>>{},
-                       // TODO: Shouldn't this be size<1>?
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
 inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                          Tensor2 &acc_o, float softmax_scale_log2) {
@@ -256,7 +209,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
-    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
@@ -558,7 +510,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Partition sO to match the accumulator partitioning
     auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
-    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
     Tensor taccOrO = smem_thr_copy_O.retile_S(rO);     // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsO = smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
...
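The softmax_rescale_o signature left as context above is the online-softmax step: when a new key block raises the running row max, the previously accumulated sum and output must be scaled by exp(m_old - m_new). A scalar CPU sketch of that identity (illustrative only; the kernel works on register tensors and uses exp2 with softmax_scale_log2):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Two "key blocks" of attention scores for one query row; V is all-ones,
    // so the exact softmax-weighted sum must come out to 1.0.
    std::vector<std::vector<float>> blocks = {{1.0f, 3.0f}, {2.0f, 5.0f}};
    float m = -INFINITY;  // running max  (scores_max)
    float l = 0.0f;       // running sum  (scores_sum)
    float acc = 0.0f;     // running P@V accumulator (acc_o), with V = 1
    for (const auto &s : blocks) {
        float m_new = m;
        for (float x : s) { m_new = std::fmax(m_new, x); }
        // Rescale earlier partial results to the new max (exp(-inf) == 0 on
        // the first block, so the stale zeros are harmless).
        const float correction = std::exp(m - m_new);
        l *= correction;
        acc *= correction;
        for (float x : s) {
            const float p = std::exp(x - m_new);
            l += p;
            acc += p * 1.0f;  // p * v with v = 1
        }
        m = m_new;
    }
    std::printf("softmax-weighted sum = %f (expect 1.0)\n", acc / l);
    return 0;
}
```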
@@ -76,7 +76,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
             // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-            // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
             auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
...
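BOOL_SWITCH in the hunk above is the usual runtime-to-compile-time dispatch: each runtime bool is branched on once, and inside each branch the value becomes a constexpr that can be passed as a kernel template parameter. A self-contained sketch of the idiom (assumption: this mirrors the repo's static_switch.h in spirit, not verbatim):

```cpp
#include <cstdio>

// Branch on a runtime condition and expose it to the body as a constexpr,
// so the body can instantiate a different template per case.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

template <bool Split, bool Append_KV>
void launch() {  // stand-in for the real kernel launch
    std::printf("instantiated with Split=%d, Append_KV=%d\n", Split, Append_KV);
}

int main() {
    int num_splits = 2;  // runtime values
    void *knew_ptr = nullptr;
    BOOL_SWITCH(num_splits > 1, Split, [&] {
        BOOL_SWITCH(knew_ptr != nullptr, Append_KV, [&] {
            launch<Split, Append_KV>();  // compile-time bools from runtime checks
        });
    });
    return 0;
}
```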
@@ -8,8 +8,7 @@
 #include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
 #include "philox.cuh"
 #include "utils.h"
...
@@ -189,7 +189,7 @@ if not SKIP_CUDA_BUILD:
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                    "--ptxas-options=-v",
+                    # "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
                     "-lineinfo"
                 ]
...