Unverified Commit cb516f85 authored by Cameron Shinn's avatar Cameron Shinn Committed by GitHub
Browse files

Remove torchlib dependency from cpp files (#1083)

parent 5f1ae4a3
......@@ -7,14 +7,6 @@
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "cutlass/fast_math.h" // For cutlass::FastDivmod
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -118,9 +110,6 @@ struct Flash_fwd_params : public Qkv_params {
// Local window size
int window_size_left, window_size_right;
// Random state.
at::PhiloxCudaState philox_args;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
......
......@@ -4,8 +4,6 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
......@@ -15,6 +13,7 @@
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
#include "kernel_traits.h"
#include "utils.h"
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
......@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
// If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
// flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
CHECK_CUDA_KERNEL_LAUNCH();
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
......@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
}
// cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
// tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
C10_CUDA_KERNEL_LAUNCH_CHECK();
CHECK_CUDA_KERNEL_LAUNCH();
auto tma_load_dQaccum = make_tma_copy(
typename cute::SM90_TMA_LOAD{},
......@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
CHECK_CUDA(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
C10_CUDA_KERNEL_LAUNCH_CHECK();
CHECK_CUDA_KERNEL_LAUNCH();
// auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
// if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// CHECK_CUDA(cudaFuncSetAttribute(
// kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
// }
// int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
// dim3 grid_n(num_n_block, params.b, params.h);
// kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// CHECK_CUDA_KERNEL_LAUNCH();
}
......
......@@ -4,8 +4,6 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
......@@ -16,6 +14,7 @@
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "utils.h"
template<typename Kernel_traits, bool Is_causal>
......@@ -66,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
int device;
......@@ -75,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
cudaError status_ = cudaDeviceGetAttribute(
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
CHECK_CUDA(status_);
}
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
......@@ -83,7 +82,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T>
......
......@@ -21,6 +21,18 @@
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
// Abort-the-process error check for CUDA runtime API calls.
// Prints file:line and the CUDA error string to stderr, then exits.
// Fixes vs. previous version:
//  - `call` is parenthesized at expansion so arguments containing commas or
//    low-precedence operators evaluate as a single expression.
//  - The internal status variable uses a macro-unique name: callers in this
//    codebase invoke CHECK_CUDA(status_) on their own variable named
//    `status_`, and the old `cudaError_t status_ = call;` expansion
//    self-initialized the shadowing local, so the real error was never seen.
#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t check_cuda_status_ = (call);                                                          \
        if (check_cuda_status_ != cudaSuccess) {                                                          \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,                               \
                    cudaGetErrorString(check_cuda_status_));                                              \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

// Kernel launches return no status directly; launch-configuration errors are
// reported by the sticky per-thread error queried via cudaGetLastError().
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
namespace flash {
using namespace cute;
......@@ -62,7 +74,7 @@ struct Allreduce {
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment