#pragma once

// NOTE(review): in this copy of the file every `<...>` span appears to have
// been stripped (include targets, template parameter lists, explicit template
// arguments), and the whole first line had collapsed into a single
// `#pragma once` directive. The include list below is a reconstruction:
// - The standard headers and <cuda_runtime.h> are what the code in this
//   section needs.
// - <torch/extension.h> supplies at::ScalarType / torch::kInt*.
// - The remaining CUDA/ATen headers are a best guess for code later in the
//   file. Confirm against version control.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <type_traits>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <torch/extension.h>

#include "util.h"

// Smallest integer type with at least `Bits` bits: 8, 16, 32 or 64.
// For Bits > 64 it resolves to `void`, so oversized masks fail to compile.
// NOTE(review): reconstructed. The four-level nesting survived the stripping,
// but the chosen element types did not; confirm signedness against the
// kernels that use MaskType.
template <int Bits>
using IntegerBits = typename std::conditional<
    Bits <= 8, uint8_t,
    typename std::conditional<
        Bits <= 16, uint16_t,
        typename std::conditional<
            Bits <= 32, uint32_t,
            typename std::conditional<Bits <= 64, uint64_t, void>::type>::type>::type>::type;

// Compile-time launch/layout constants for a softmax over rows of
// 2^LogElements elements (at most 2048).
template <int LogElements>
struct SoftmaxParameters {
  static_assert(LogElements >= 0 && LogElements <= 11, "LogElements must be in [0, 11]");
  // Padded row length.
  static constexpr int Elements = 1 << LogElements;
  // Short rows (<= 128) use 2 per warp; presumably rows handled per warp — confirm in kernel.
  static constexpr int WarpBatch = Elements <= 128 ? 2 : 1;
  // Elements per lane once the row is spread over a 32-wide warp.
  static constexpr int WarpIterations = Elements <= 32 ? 1 : Elements / 32;
  // One mask bit per element a lane handles. The resulting widths (8/16/32/64)
  // match softmax_mask_dtype().
  using MaskType = IntegerBits<WarpIterations>;
  // Rows narrower than 32 use a correspondingly narrower "warp".
  static constexpr int WarpSize = Elements <= 32 ? Elements : 32;
  static constexpr int MaskStride = WarpSize;
};

// Smallest n >= 0 such that 2^n >= value. Returns 0 for value <= 1.
inline int log2_ceil(int value) {
  int log2_value = 0;
  // Probe with a 64-bit one: `1 << 31` on a plain int overflows (UB) when value > 2^30.
  while ((int64_t{1} << log2_value) < value) ++log2_value;
  return log2_value;
}

// Torch dtype for the dropout-mask tensor of a row with `elements` entries.
// Its width equals sizeof(SoftmaxParameters<log2_ceil(elements)>::MaskType).
// Note: values above 2048 still map to kInt64, although SoftmaxParameters
// itself only accepts rows of up to 2^11 elements.
inline at::ScalarType softmax_mask_dtype(int elements) {
  if (elements > 1024) {
    return torch::kInt64;
  } else if (elements > 512) {
    return torch::kInt32;
  } else if (elements > 256) {
    return torch::kInt16;
  }
  return torch::kInt8;
}

// Number of mask entries for `batch_size` rows: one per lane per row, i.e.
// batch_size * WarpSize (== MaskStride) for the padded row length.
inline int softmax_mask_size(int batch_size, int elements) {
  int log2_elements = log2_ceil(elements);
  int e = 1 << log2_elements;
  int warp_size = e < 32 ? e : 32;
  return batch_size * warp_size;
}

// Returns WarpIterations * WarpBatch for the padded row length.
// NOTE(review): presumably the per-thread RNG offset increment (one draw per
// element a lane touches) — confirm against the launching code.
inline int softmax_rng_delta_offset(int elements) {
  int log2_elements = log2_ceil(elements);
  int e = 1 << log2_elements;
  int warp_iterations = e <= 32 ? 1 : e / 32;
  int warp_batch = e <= 128 ? 2 : 1;
  return warp_iterations * warp_batch;
}

// Chooses a grid size for the current device.
// - Target: enough blocks of `block_size` threads to fill every SM's resident
//   thread capacity (SM count * max threads per SM) `waves` times.
// - The result is clamped to [1, max_blocks], and additionally to INT_MAX
//   because it is stored in an int.
// - Returns cudaErrorInvalidValue for a non-positive block_size or a null
//   out-pointer. Otherwise it returns the first failing CUDA runtime status,
//   or cudaSuccess.
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
                                int *num_blocks) {
  if (block_size <= 0 || num_blocks == nullptr) { return cudaErrorInvalidValue; }

  int dev = 0;
  cudaError_t err = cudaGetDevice(&dev);
  if (err != cudaSuccess) { return err; }

  int sm_count = 0;
  err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  if (err != cudaSuccess) { return err; }

  int tpm = 0;
  err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
  if (err != cudaSuccess) { return err; }

  // Same evaluation order as before (capacity / block_size, then * waves),
  // but done entirely in int64_t so std::min/std::max deduce one type.
  const int64_t occupancy_blocks = static_cast<int64_t>(sm_count) * tpm / block_size * waves;
  const int64_t clamped = std::max<int64_t>(1, std::min<int64_t>(max_blocks, occupancy_blocks));
  *num_blocks = static_cast<int>(
      std::min<int64_t>(clamped, static_cast<int64_t>(std::numeric_limits<int>::max())));
  return cudaSuccess;
}

// Binary sum functor for device-side reductions.
template <typename T>
struct SumOp {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; }
};

// Binary max functor for device-side reductions. Relies on a `max` overload for T.
template <typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const { return max(a, b); }
};

// NOTE(review): the bare `template` keyword below begins the following
// definition. Its parameter list appears to have been stripped along with the
// other `<...>` text; restore it there.
template