#include "misc_kernels_impl.cuh" #include "misc_kernels.h" #include "dispatch_utils.h" Tensor add(Tensor a, Tensor b) { assert(a.shape.dataExtent == b.shape.dataExtent); assert(a.dtype() == b.dtype()); assert(a.is_contiguous()); assert(b.is_contiguous()); int threadsPerBlock = 1024; int blocksPerGrid = (a.numel() + threadsPerBlock - 1) / threadsPerBlock; auto stream = getCurrentCUDAStream(); Tensor out = Tensor::empty_like(a); dispatch(out.scalar_type(), [&]() { add_kernel<<>>( a.data_ptr(), b.data_ptr(), out.data_ptr(), out.numel()); }); return out; } void mul_add(Tensor x, Tensor scale, Tensor bias) { // assert(scale.shape.data == bias.shape.data); // FIXME FIXME assert(x.numel() % scale.numel() == 0); assert(x.numel() % bias.numel() == 0); assert(x.dtype() == scale.dtype()); assert(x.dtype() == bias.dtype()); constexpr int unroll = 8; assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0); assert((uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0); assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0); assert(x.numel() % unroll == 0); assert(scale.numel() % unroll == 0); assert(bias.numel() % unroll == 0); int threadsPerBlock = 1024; int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll); auto stream = getCurrentCUDAStream(); dispatch(x.scalar_type(), [&]() { mul_add_kernel<<>>( x.data_ptr(), scale.data_ptr(), bias.data_ptr(), x.numel(), scale.numel(), bias.numel()); }); } Tensor embedding(Tensor input_id, Tensor lookup) { assert(input_id.dtype() == Tensor::INT32); assert(lookup.ndims() == 2); auto shapeOut = input_id.shape; shapeOut.dataExtent.push_back(lookup.shape[-1]); auto stream = getCurrentCUDAStream(); Tensor out = Tensor::empty(shapeOut, lookup.scalar_type(), input_id.device()); dispatch(out.scalar_type(), [&]() { EmbeddingKernel<<>>( input_id.data_ptr(), out.data_ptr(), lookup.data_ptr(), lookup.shape[-1]); }); return out; } Tensor argmax_sample(Tensor logits) { assert(logits.ndims() == 2); auto stream = getCurrentCUDAStream(); Tensor out = Tensor::empty({logits.shape[0]}, Tensor::INT32, logits.device()); dispatch(logits.scalar_type(), [&]() { argmax_sample_kernel<<>>( logits.data_ptr(), out.data_ptr(), logits.shape[1] ); }); return out; } void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) { // FIXME FIXME // assert(qkv.shape[0] == q.shape[0]); // assert(qkv.shape[0] == k.shape[0]); // assert(qkv.shape[0] == v.shape[0]); auto stream = getCurrentCUDAStream(); int dim_q = q.shape[-1] * q.shape[-2]; int dim_k = k.shape[-1] * k.shape[-2]; int dim_v = v.shape[-1] * v.shape[-2]; assert(dim_k == dim_v); assert(dim_q + dim_k + dim_v == qkv.shape[-1]); int num_tokens = qkv.numel() / qkv.shape[-1]; dispatch(qkv.scalar_type(), [&]() { splitqkv_kernel<<>>( qkv.data_ptr(), q.data_ptr(), k.data_ptr(), v.data_ptr(), dim_q, dim_k ); }); } template std::array split_mod(Tensor input) { assert(input.shape[-1] % N == 0); int threadsPerBlock = 1024; int blocksPerGrid = (input.numel() + threadsPerBlock - 1) / threadsPerBlock; auto stream = getCurrentCUDAStream(); auto shapeOut = input.shape; shapeOut[-1] /= N; std::array out; for (int k = 0; k < N; k++) { out[k] = Tensor::empty(shapeOut, input.scalar_type(), input.device()); } dispatch(input.scalar_type(), [&]() { std::array outPtr; for (int k = 0; k < N; k++) { outPtr[k] = out[k].template data_ptr(); } split_mod_kernel<<>>( input.data_ptr(), outPtr, input.numel()); }); return out; } Tensor quant_static(Tensor x, float scale) { Tensor out = Tensor::empty(x.shape, 
    constexpr int unroll = 8;
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        quant_kernel_static<scalar_t><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
    });

    return out;
}

// Applies GELU and quantizes the result to int8 with a fixed per-tensor scale.
Tensor quant_static_fuse_gelu(Tensor x, float scale) {
    Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());

    constexpr int unroll = 8;
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        quant_kernel_static_fuse_gelu<scalar_t><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
    });

    return out;
}

// Elementwise dtype conversion between two contiguous tensors of the same shape.
void cast(Tensor input, Tensor output) {
    assert(input.is_contiguous());
    assert(output.is_contiguous());
    assert(input.shape.dataExtent == output.shape.dataExtent);

    auto stream = getCurrentCUDAStream();

    dispatch(input.scalar_type(), [&]<typename input_t>() {
        dispatch(output.scalar_type(), [&]<typename output_t>() {
            constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));

            int threadsPerBlock = 1024;
            int blocksPerGrid   = (int)ceilDiv(input.numel(), threadsPerBlock * unroll);

            cast_kernel<input_t, output_t><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
                input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());
            checkCUDA(cudaGetLastError());
        });
    });
}

// Top-k indices along the last dim of x, returned as int32 (k is bounded by MAXK).
Tensor topk(Tensor x, int k) {
    constexpr int MAXK = 64 + 4;

    const int N     = x.shape[-1];
    const int batch = x.numel() / N;

    assert(k <= N);
    assert(k <= MAXK);

    auto outShape = x.shape;
    outShape[-1] = k;
    outShape.dataStride.clear();

    Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device());

    auto stream = getCurrentCUDAStream();

    dispatchVal(k, std::make_integer_sequence<int, MAXK + 1>(), [&]<int K>() {
        if constexpr (K == 0) {
            assert(false);
            return;
        }
        if constexpr (K > 0) {
            dispatch(x.scalar_type(), [&]<typename scalar_t>() {
                // Assumed launch configuration: one thread per row, grid sized over the batch.
                topk_kernel<scalar_t, K><<<(int)ceilDiv(batch, 256), 256, 0, stream>>>(
                    x.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), N, x.stride(-2), batch);
                checkCUDA(cudaGetLastError());
            });
        }
    });

    return out;
}

template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);
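
// ---------------------------------------------------------------------------
// Illustrative usage (a minimal sketch, not part of the build). The tensor
// names below are hypothetical and the trailing comments describe assumed
// semantics; inputs must satisfy each wrapper's contiguity/alignment asserts.
//
//   Tensor sum     = add(residual, hidden);       // elementwise sum, same shape and dtype
//   Tensor embeds  = embedding(token_ids, wte);   // appends the embedding dim to token_ids' shape
//   Tensor next    = argmax_sample(logits);       // [batch, vocab] -> [batch] int32 indices
//   auto [a, b, c] = split_mod<3>(fused);         // last dim split into 3 equal parts
//   Tensor x_q     = quant_static(x, 0.05f);      // int8 quantization with a fixed scale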