#include <cuda_runtime.h>

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

// Aborts the process with a file/line diagnostic if a CUDA runtime call fails.
#define CHECK_CUDA(call)                                                    \
  do {                                                                      \
    cudaError_t err = (call);                                               \
    if (err != cudaSuccess) {                                               \
      std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \
                << cudaGetErrorString(err) << std::endl;                    \
      exit(EXIT_FAILURE);                                                   \
    }                                                                       \
  } while (0)

// Block tile: each thread block computes a kBlockM x kBlockN tile of C and
// walks the K dimension in slices of kBlockK.
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
// Kept small so both shared-memory tiles (2 * 128 * 16 floats = 16 KiB) stay
// well below the 48 KiB limit on *static* shared memory per block.
constexpr int kBlockK = 16;

// Thread tile: each thread accumulates kThreadM x kThreadN elements of C.
constexpr int kThreadM = 8;
constexpr int kThreadN = 8;

// Threads form a kThreadsM x kThreadsN grid over the block tile.
constexpr int kThreadsM = kBlockM / kThreadM;            // 16
constexpr int kThreadsN = kBlockN / kThreadN;            // 16
constexpr int kThreadsPerBlock = kThreadsM * kThreadsN;  // 256

static_assert(kBlockM % kThreadM == 0, "kThreadM must divide kBlockM");
static_assert(kBlockN % kThreadN == 0, "kThreadN must divide kBlockN");
static_assert(kThreadsPerBlock % 32 == 0 && kThreadsPerBlock <= 1024,
              "block must be a whole number of warps and at most 1024 threads");
static_assert(2 * kBlockK * (kBlockM + kBlockN) * sizeof(float) <= 48 * 1024,
              "static shared memory per block must not exceed 48 KiB");

// Computes C = alpha * A * B + beta * C for column-major matrices
// A (M x K, lda = M), B (K x N, ldb = K) and C (M x N, ldc = M).
//
// Launch layout (precondition): 1-D blocks of exactly kThreadsPerBlock threads
// and a 2-D grid of ceil(N / kBlockN) x ceil(M / kBlockM) blocks (grid x walks
// N, grid y walks M). Uses kBlockK * (kBlockM + kBlockN) floats of static
// shared memory. When beta == 0, C is only written, never read (BLAS
// convention), so it may hold uninitialized data on entry.
__global__ void __launch_bounds__(kThreadsPerBlock) TiledGemmKernel(
    int M, int N, int K, float alpha, const float* __restrict__ A,
    const float* __restrict__ B, float beta, float* __restrict__ C) {
  const int lda = M;
  const int ldb = K;
  const int ldc = M;

  // A tile stored k-major (smemA[k][m]) and B tile stored n-major
  // (smemB[n][k]): the same element order as column-major global memory, so
  // global loads are coalesced and shared stores are conflict-free.
  __shared__ float smemA[kBlockK][kBlockM];
  __shared__ float smemB[kBlockN][kBlockK];

  const int tid = threadIdx.x;
  // Adjacent threads own adjacent rows, so the final column-major stores of a
  // warp hit contiguous addresses.
  const int threadRow = tid % kThreadsM;
  const int threadCol = tid / kThreadsM;

  const int blockRow = blockIdx.y * kBlockM;
  const int blockCol = blockIdx.x * kBlockN;

  // Thread (threadRow, threadCol) owns C rows threadRow + i * kThreadsM and
  // columns threadCol + j * kThreadsN of the block tile.
  float acc[kThreadM][kThreadN];
#pragma unroll
  for (int i = 0; i < kThreadM; ++i) {
#pragma unroll
    for (int j = 0; j < kThreadN; ++j) {
      acc[i][j] = 0.0f;
    }
  }

  for (int kTile = 0; kTile < K; kTile += kBlockK) {
    // Stage the kBlockM x kBlockK slice of A; out-of-range elements become 0
    // so partial edge tiles need no special case in the compute loop.
    for (int i = tid; i < kBlockM * kBlockK; i += kThreadsPerBlock) {
      const int m = i % kBlockM;
      const int k = i / kBlockM;
      const int globalRow = blockRow + m;
      const int globalCol = kTile + k;
      smemA[k][m] = (globalRow < M && globalCol < K)
                        ? A[globalRow + static_cast<size_t>(globalCol) * lda]
                        : 0.0f;
    }
    // Stage the kBlockK x kBlockN slice of B (zero-padded likewise).
    for (int i = tid; i < kBlockK * kBlockN; i += kThreadsPerBlock) {
      const int k = i % kBlockK;
      const int n = i / kBlockK;
      const int globalRow = kTile + k;
      const int globalCol = blockCol + n;
      smemB[n][k] = (globalRow < K && globalCol < N)
                        ? B[globalRow + static_cast<size_t>(globalCol) * ldb]
                        : 0.0f;
    }
    __syncthreads();

    // One rank-1 update per k: the thread's column fragment of A times its
    // row fragment of B, both taken at the SAME k.
#pragma unroll
    for (int k = 0; k < kBlockK; ++k) {
      float aFrag[kThreadM];
      float bFrag[kThreadN];
#pragma unroll
      for (int i = 0; i < kThreadM; ++i) {
        aFrag[i] = smemA[k][threadRow + i * kThreadsM];
      }
#pragma unroll
      for (int j = 0; j < kThreadN; ++j) {
        bFrag[j] = smemB[threadCol + j * kThreadsN][k];
      }
#pragma unroll
      for (int i = 0; i < kThreadM; ++i) {
#pragma unroll
        for (int j = 0; j < kThreadN; ++j) {
          acc[i][j] = fmaf(aFrag[i], bFrag[j], acc[i][j]);
        }
      }
    }
    // Prevent the next iteration's loads from overwriting tiles still in use.
    __syncthreads();
  }

  // Write back, skipping elements that fall outside C on the ragged edges.
#pragma unroll
  for (int i = 0; i < kThreadM; ++i) {
    const int row = blockRow + threadRow + i * kThreadsM;
    if (row >= M) continue;
#pragma unroll
    for (int j = 0; j < kThreadN; ++j) {
      const int col = blockCol + threadCol + j * kThreadsN;
      if (col >= N) continue;
      const size_t idx =
          static_cast<size_t>(row) + static_cast<size_t>(col) * ldc;
      const float prod = alpha * acc[i][j];
      C[idx] = (beta == 0.0f) ? prod : prod + beta * C[idx];
    }
  }
}

// Host wrapper: launches TiledGemmKernel on the default stream and blocks
// until it has finished. A, B and C must be device pointers (column-major,
// same layout as the kernel).
void TiledGemm(int M, int N, int K, float alpha, const float* A,
               const float* B, float beta, float* C) {
  // Empty output: nothing to compute, and a zero-sized grid is a launch error.
  if (M <= 0 || N <= 0) return;

  const dim3 block(kThreadsPerBlock);
  const dim3 grid((N + kBlockN - 1) / kBlockN, (M + kBlockM - 1) / kBlockM);
  TiledGemmKernel<<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
  CHECK_CUDA(cudaGetLastError());       // launch-configuration errors
  CHECK_CUDA(cudaDeviceSynchronize());  // errors raised while the kernel ran
}

// CPU reference for the same column-major GEMM (C = alpha * A * B + beta * C).
// Accumulates in double so it is a tighter reference than the GPU's float
// accumulation. C is not read when beta == 0 (BLAS convention).
void ReferenceGemm(int M, int N, int K, float alpha, const float* A,
                   const float* B, float beta, float* C) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      double sum = 0.0;
      for (int k = 0; k < K; ++k) {
        sum += static_cast<double>(A[i + static_cast<size_t>(k) * M]) *
               static_cast<double>(B[k + static_cast<size_t>(j) * K]);
      }
      const size_t idx = i + static_cast<size_t>(j) * M;
      const float prod = alpha * static_cast<float>(sum);
      C[idx] = (beta == 0.0f) ? prod : prod + beta * C[idx];
    }
  }
}

// Fills data[0, size) with pseudo-random values in [-1, 1] drawn from rand().
void RandomInit(float* data, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    data[i] = (static_cast<float>(rand()) / RAND_MAX) * 2.0f - 1.0f;
  }
}

// Compares a computed column-major M x N matrix C1 against a reference C2.
// An element passes when |C1 - C2| <= tolerance * max(1, |C2|): an absolute
// bound for small magnitudes, a relative one for large magnitudes. A NaN
// difference counts as a mismatch. Reports the first failing element on
// stderr and returns false; returns true if every element passes.
bool Verify(const float* C1, const float* C2, int M, int N,
            float tolerance = 1e-3f) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      const size_t idx = i + static_cast<size_t>(j) * M;
      const float diff = fabsf(C1[idx] - C2[idx]);
      const float bound = tolerance * fmaxf(1.0f, fabsf(C2[idx]));
      // Written as !(diff <= bound) so that NaN fails the check.
      if (!(diff <= bound)) {
        std::cerr << "Mismatch at C[" << i << "," << j << "]: " << C1[idx]
                  << " vs " << C2[idx] << " (diff=" << diff << ")"
                  << std::endl;
        return false;
      }
    }
  }
  return true;
}

// Benchmark driver: `prog [M N K]` (defaults 512 512 512). Times TiledGemm,
// checks it against ReferenceGemm and exits with 0 on PASS, 1 on FAIL.
int main(int argc, char** argv) {
  int M = 512;
  int N = 512;
  int K = 512;
  const float alpha = 1.0f;
  const float beta = 0.0f;

  if (argc >= 4) {
    M = atoi(argv[1]);
    N = atoi(argv[2]);
    K = atoi(argv[3]);
  }
  // atoi yields 0 on malformed input; reject that and negative sizes early.
  if (M <= 0 || N <= 0 || K <= 0) {
    std::cerr << "M, N and K must be positive integers" << std::endl;
    return 1;
  }

  std::cout << "GEMM: M=" << M << ", N=" << N << ", K=" << K << std::endl;

  // Element counts in size_t: int products overflow for large matrices.
  const size_t sizeA = static_cast<size_t>(M) * K;
  const size_t sizeB = static_cast<size_t>(K) * N;
  const size_t sizeC = static_cast<size_t>(M) * N;

  std::vector<float> h_A(sizeA);
  std::vector<float> h_B(sizeB);
  std::vector<float> h_C_tiled(sizeC);
  // Zero-initialized so it matches d_C (memset to 0 below) if beta != 0.
  std::vector<float> h_C_ref(sizeC, 0.0f);

  RandomInit(h_A.data(), sizeA);
  RandomInit(h_B.data(), sizeB);

  float *d_A = nullptr, *d_B = nullptr, *d_C = nullptr;
  CHECK_CUDA(cudaMalloc(&d_A, sizeA * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_B, sizeB * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_C, sizeC * sizeof(float)));

  CHECK_CUDA(cudaMemcpy(d_A, h_A.data(), sizeA * sizeof(float),
                        cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_B, h_B.data(), sizeB * sizeof(float),
                        cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemset(d_C, 0, sizeC * sizeof(float)));

  // Warm-up launch so one-time costs (context/module loading) are not timed;
  // C is then reset so the timed run starts from the same state.
  TiledGemm(M, N, K, alpha, d_A, d_B, beta, d_C);
  CHECK_CUDA(cudaMemset(d_C, 0, sizeC * sizeof(float)));

  cudaEvent_t start, stop;
  CHECK_CUDA(cudaEventCreate(&start));
  CHECK_CUDA(cudaEventCreate(&stop));

  CHECK_CUDA(cudaEventRecord(start));
  TiledGemm(M, N, K, alpha, d_A, d_B, beta, d_C);
  CHECK_CUDA(cudaEventRecord(stop));
  CHECK_CUDA(cudaEventSynchronize(stop));

  float milliseconds = 0;
  CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

  CHECK_CUDA(cudaMemcpy(h_C_tiled.data(), d_C, sizeC * sizeof(float),
                        cudaMemcpyDeviceToHost));

  ReferenceGemm(M, N, K, alpha, h_A.data(), h_B.data(), beta, h_C_ref.data());
  const bool passed = Verify(h_C_tiled.data(), h_C_ref.data(), M, N);

  const double tflops =
      (2.0 * M * N * K) / (static_cast<double>(milliseconds) * 1e-3) / 1e12;

  std::cout << "Tiled GEMM: " << milliseconds << " ms" << std::endl;
  std::cout << "Performance: " << tflops << " TFLOPS" << std::endl;
  std::cout << "Result: " << (passed ? "PASSED" : "FAILED") << std::endl;

  CHECK_CUDA(cudaEventDestroy(start));
  CHECK_CUDA(cudaEventDestroy(stop));
  CHECK_CUDA(cudaFree(d_A));
  CHECK_CUDA(cudaFree(d_B));
  CHECK_CUDA(cudaFree(d_C));

  return passed ? 0 : 1;
}