Unverified Commit cb516f85 authored by Cameron Shinn, committed by GitHub

Remove torchlib dependency from cpp files (#1083)

parent 5f1ae4a3
@@ -7,14 +7,6 @@
 #include <cuda.h>
 #include <vector>
-#ifdef OLD_GENERATOR_PATH
-#include <ATen/CUDAGeneratorImpl.h>
-#else
-#include <ATen/cuda/CUDAGeneratorImpl.h>
-#endif
-#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
 #include "cutlass/fast_math.h" // For cutlass::FastDivmod
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -118,9 +110,6 @@ struct Flash_fwd_params : public Qkv_params {
 // Local window size
 int window_size_left, window_size_right;
-// Random state.
-at::PhiloxCudaState philox_args;
 // Pointer to the RNG seed (idx 0) and offset (idx 1).
 uint64_t * rng_state;
...
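The hunks above drop the ATen Philox headers and the `at::PhiloxCudaState philox_args` member, leaving `Flash_fwd_params` with only the raw `uint64_t *rng_state` pointer (seed at idx 0, offset at idx 1). As a minimal sketch of why a plain two-element buffer is enough, the snippet below stages a seed/offset pair and reconstructs a Philox generator from it on the device. It uses cuRAND purely for illustration (flash-attention carries its own Philox implementation), and the kernel, helper names, and seeding policy here are assumptions, not part of this commit.

```cpp
// Hypothetical sketch: consume a raw seed/offset pair without at::PhiloxCudaState.
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <cstdint>
#include <cstdio>

__global__ void sample_kernel(const uint64_t *rng_state, float *out) {
    curandStatePhilox4_32_10_t state;
    // Seed at idx 0, offset at idx 1, matching the comment on Flash_fwd_params::rng_state.
    curand_init(rng_state[0], threadIdx.x /*subsequence*/, rng_state[1] /*offset*/, &state);
    out[threadIdx.x] = curand_uniform(&state);
}

int main() {
    uint64_t h_state[2] = {1234ULL /*seed*/, 0ULL /*offset*/};   // illustrative values
    uint64_t *d_state; float *d_out; float h_out[32];
    cudaMalloc(&d_state, sizeof(h_state));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_state, h_state, sizeof(h_state), cudaMemcpyHostToDevice);
    sample_kernel<<<1, 32>>>(d_state, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("first sample = %f\n", h_out[0]);
    cudaFree(d_state); cudaFree(d_out);
    return 0;
}
```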
@@ -4,8 +4,6 @@
 #pragma once
-#include <ATen/cuda/CUDAContext.h>
 #include "cute/tensor.hpp"
 #include "cutlass/cluster_launch.hpp"
@@ -15,6 +13,7 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 #include "kernel_traits.h"
+#include "utils.h"
 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
 // If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
 // flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
+CHECK_CUDA_KERNEL_LAUNCH();
 using Element = typename Kernel_traits::Element;
 using ElementAccum = typename Kernel_traits::ElementAccum;
@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
 // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
 if (smem_size >= 48 * 1024) {
-C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
 }
 static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 }
 // cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
 //                                   tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
+CHECK_CUDA_KERNEL_LAUNCH();
 auto tma_load_dQaccum = make_tma_copy(
 typename cute::SM90_TMA_LOAD{},
@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
 auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
 if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
-C10_CUDA_CHECK(cudaFuncSetAttribute(
+CHECK_CUDA(cudaFuncSetAttribute(
 kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
 }
 kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
+CHECK_CUDA_KERNEL_LAUNCH();
 // auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
 // if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
-//     C10_CUDA_CHECK(cudaFuncSetAttribute(
+//     CHECK_CUDA(cudaFuncSetAttribute(
 //         kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
 // }
 // int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
 // dim3 grid_n(num_n_block, params.b, params.h);
 // kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
-// C10_CUDA_KERNEL_LAUNCH_CHECK();
+// CHECK_CUDA_KERNEL_LAUNCH();
 }
...
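The backward-launcher changes above keep the existing pattern — opt the kernel into more than 48 KB of dynamic shared memory with `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)`, launch, then check the launch — but route error handling through the standalone `CHECK_CUDA` / `CHECK_CUDA_KERNEL_LAUNCH` macros instead of torch's `C10_CUDA_CHECK` / `C10_CUDA_KERNEL_LAUNCH_CHECK`. Below is a self-contained sketch of that pattern with the macros copied from the utils.h hunk further down; the toy kernel and the 64 KB figure are illustrative only, and the opt-in requires a GPU that supports more than 48 KB of dynamic shared memory (SM 7.0 or newer).

```cpp
// Sketch: opt a kernel into >48 KB dynamic shared memory, launch, check the launch.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Same macros this commit adds to utils.h.
#define CHECK_CUDA(call) \
    do { \
        cudaError_t status_ = call; \
        if (status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1); \
        } \
    } while(0)
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

__global__ void fill_kernel(float *out) {
    extern __shared__ float smem[];   // dynamic shared memory
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
    constexpr int smem_size = 64 * 1024;   // illustrative; above the 48 KB default
    if (smem_size >= 48 * 1024) {          // same guard the launcher uses
        CHECK_CUDA(cudaFuncSetAttribute(fill_kernel,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize,
                                        smem_size));
    }
    float *out;
    CHECK_CUDA(cudaMalloc(&out, 256 * sizeof(float)));
    fill_kernel<<<1, 256, smem_size>>>(out);
    CHECK_CUDA_KERNEL_LAUNCH();            // replaces C10_CUDA_KERNEL_LAUNCH_CHECK()
    CHECK_CUDA(cudaDeviceSynchronize());
    CHECK_CUDA(cudaFree(out));
    return 0;
}
```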
@@ -4,8 +4,6 @@
 #pragma once
-#include <ATen/cuda/CUDAContext.h>
 #include "cute/tensor.hpp"
 #include "cutlass/cutlass.h"
@@ -16,6 +14,7 @@
 #include "tile_scheduler.hpp"
 #include "flash_fwd_kernel.h"
 #include "kernel_traits.h"
+#include "utils.h"
 template<typename Kernel_traits, bool Is_causal>
@@ -66,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
 // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
 if (smem_size >= 48 * 1024) {
-C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
 }
 int device;
@@ -75,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 cudaError status_ = cudaDeviceGetAttribute(
 &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
 if (status_ != cudaSuccess) {
-C10_CUDA_CHECK(status_);
+CHECK_CUDA(status_);
 }
 dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
 static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
@@ -83,7 +82,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
 cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
 cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
+CHECK_CUDA_KERNEL_LAUNCH();
 }
 template<typename T>
...
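The forward launcher makes the same substitution: the shared-memory opt-in and the SM-count query that sizes the persistent scheduler's grid are now checked with `CHECK_CUDA` rather than `C10_CUDA_CHECK`. A small sketch of that device-attribute query follows; it assumes the `CHECK_CUDA` macro from the utils.h hunk below is in scope, and the helper name is hypothetical.

```cpp
// Sketch of the SM-count query used to size the scheduler grid.
// Assumes CHECK_CUDA (added to utils.h in this commit) is in scope.
#include <cuda_runtime.h>

static int multiprocessor_count_for_current_device() {   // hypothetical helper
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    int multiprocessor_count = 0;
    // Same attribute run_flash_fwd reads before computing grid_dims.
    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count,
                                      cudaDevAttrMultiProcessorCount, device));
    return multiprocessor_count;
}
```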
@@ -21,6 +21,18 @@
 #include <cutlass/numeric_conversion.h>
 #include <cutlass/numeric_types.h>
+#define CHECK_CUDA(call) \
+    do { \
+        cudaError_t status_ = call; \
+        if (status_ != cudaSuccess) { \
+            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
+            exit(1); \
+        } \
+    } while(0)
+#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
 namespace flash {
 using namespace cute;
@@ -62,7 +74,7 @@ struct Allreduce {
 template<>
 struct Allreduce<2> {
 template<typename T, typename Operator>
 static __device__ __forceinline__ T run(T x, Operator &op) {
 x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
 return x;
...
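For context, the `Allreduce<2>` specialization shown above combines a value across a pair of lanes with one `__shfl_xor_sync`; applied recursively over larger strides it gives the usual XOR-butterfly warp reduction. A standalone sketch of that pattern (the kernel and launch below are illustrative, not code from this file):

```cpp
// Standalone sketch of the XOR-butterfly warp allreduce that flash::Allreduce
// implements recursively. Each lane ends up holding the sum over the warp.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void warp_sum_kernel(const float *in, float *out) {
    float x = in[threadIdx.x];
    // Strides 16, 8, 4, 2, 1: after the last step every lane has the warp sum.
    for (int offset = 16; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    out[threadIdx.x] = x;
}

int main() {
    float h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_sum_kernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 0 sum = %f\n", h_out[0]);   // expect 32.0
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}
```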