demo.cu 7.08 KB
Newer Older
liuys's avatar
liuys committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#include <cuda_runtime.h>
#include <iostream>
#include <cmath>
#include <cstdlib>

// Error-checking wrapper for CUDA runtime calls: evaluates `call`, and on any
// error prints file/line plus the CUDA error string to stderr and exits.
// Wrap every cudaXxx() call that returns cudaError_t with this macro.
#define CHECK_CUDA(call) \
  do { \
    cudaError_t err = call; \
    if (err != cudaSuccess) { \
      std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \
                << cudaGetErrorString(err) << std::endl; \
      exit(EXIT_FAILURE); \
    } \
  } while (0)

// Nominal per-MMA-step fragment dimensions (M/N/K).
// NOTE(review): kMmaM and kMmaN do not appear to be referenced anywhere in
// this file — confirm whether they are intended for a future tensor-core path.
constexpr int kMmaM = 16;
constexpr int kMmaN = 16;
constexpr int kMmaK = 16;

// Nominal per-warp output tile (kWarpM x kWarpN) and K depth.
// NOTE(review): kWarpK also appears unused in this file.
constexpr int kWarpM = 64;
constexpr int kWarpN = 64;
constexpr int kWarpK = 32;

// Per-thread-block output tile (kBlockM x kBlockN) and shared-memory K depth.
// Used by the host launcher to size the grid.
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kBlockK = 64;

// Warp-tile grid inside a block tile: 2 x 2 = 4 warp tiles.
// NOTE(review): the launcher uses 256-thread blocks (8 warps), which is more
// than kWarpNumM * kWarpNumN = 4 warp tiles — verify the warp-to-tile mapping.
constexpr int kWarpNumM = kBlockM / kWarpM;
constexpr int kWarpNumN = kBlockN / kWarpN;

__global__ void TiledGemmKernel(
  int M, int N, int K,
  float alpha,
  const float* __restrict__ A,
  const float* __restrict__ B,
  float beta,
  float* __restrict__ C) {
  // Computes C = alpha * A * B + beta * C for column-major matrices:
  //   A is M x K (lda = M), B is K x N (ldb = K), C is M x N (ldc = M).
  //
  // Launch contract: 1-D blocks with blockDim.x == 256. The kernel grid-strides
  // over output tiles, so ANY non-empty grid produces a correct (if slower)
  // result. Static shared memory: (BM*(BK+1) + BK*BN) floats = 8.25 KB.
  //
  // Fixes vs. the previous version:
  //  * 256-thread blocks held 8 warps but only kWarpNumM*kWarpNumN = 4 warp
  //    tiles existed, so warps 4-7 indexed smemA rows >= 128 (out of bounds).
  //  * The lane->fragment mapping (laneId/4, laneId%4, stride 4) covered only
  //    a 20x16 corner of each warp's nominal 64x64 tile, with overlapping
  //    rows, leaving most of C unwritten.
  //  * Static shared memory was 128*64*4*2 = 64 KB, over the 48 KB static
  //    limit. The 64x64x16 tiling below uses 8.25 KB.
  constexpr int BM = 64;   // rows of C per block tile
  constexpr int BN = 64;   // cols of C per block tile
  constexpr int BK = 16;   // K-depth of each shared-memory tile
  constexpr int TM = 4;    // rows of C per thread
  constexpr int TN = 4;    // cols of C per thread
  // 16 x 16 thread fragments cover the tile exactly: 256 compute threads.
  constexpr int kComputeThreads = (BM / TM) * (BN / TN);

  const int lda = M;
  const int ldb = K;
  const int ldc = M;

  // +1 padding on smemA's inner dimension so the column-wise reads in the
  // inner-product loop (fixed kk, varying row) land in distinct banks.
  __shared__ float smemA[BM][BK + 1];
  __shared__ float smemB[BK][BN];

  const int tid = threadIdx.x;
  // Map compute threads onto a 16x16 grid of contiguous TMxTN fragments.
  const int threadRow = (tid / (BN / TN)) * TM;  // 0, 4, ..., 60
  const int threadCol = (tid % (BN / TN)) * TN;  // 0, 4, ..., 60

  const int tilesM = (M + BM - 1) / BM;
  const int tilesN = (N + BN - 1) / BN;
  const int numKTiles = (K + BK - 1) / BK;

  // Grid-stride over output tiles: loop bounds depend only on M/N and the
  // grid, so every thread in a block executes the same iterations and the
  // __syncthreads() below are reached uniformly.
  for (int tileRow = blockIdx.y; tileRow < tilesM; tileRow += gridDim.y) {
    for (int tileCol = blockIdx.x; tileCol < tilesN; tileCol += gridDim.x) {
      const int blockRow = tileRow * BM;
      const int blockCol = tileCol * BN;

      float acc[TM][TN] = {};  // per-thread TMxTN accumulator, reset per tile

      for (int t = 0; t < numKTiles; ++t) {
        // Cooperative load of the A tile (BM x BK), zero-padding the edges.
        for (int i = tid; i < BM * BK; i += blockDim.x) {
          const int r = i / BK;
          const int c = i % BK;
          const int gRow = blockRow + r;
          const int gCol = t * BK + c;
          smemA[r][c] = (gRow < M && gCol < K) ? A[gRow + gCol * lda] : 0.0f;
        }
        // Cooperative load of the B tile (BK x BN), zero-padding the edges.
        for (int i = tid; i < BK * BN; i += blockDim.x) {
          const int r = i / BN;
          const int c = i % BN;
          const int gRow = t * BK + r;
          const int gCol = blockCol + c;
          smemB[r][c] = (gRow < K && gCol < N) ? B[gRow + gCol * ldb] : 0.0f;
        }
        __syncthreads();

        if (tid < kComputeThreads) {
          for (int kk = 0; kk < BK; ++kk) {
            float aFrag[TM];
            float bFrag[TN];
            #pragma unroll
            for (int i = 0; i < TM; ++i) {
              aFrag[i] = smemA[threadRow + i][kk];
            }
            #pragma unroll
            for (int j = 0; j < TN; ++j) {
              bFrag[j] = smemB[kk][threadCol + j];
            }
            // Rank-1 update of this thread's fragment.
            #pragma unroll
            for (int i = 0; i < TM; ++i) {
              #pragma unroll
              for (int j = 0; j < TN; ++j) {
                acc[i][j] += aFrag[i] * bFrag[j];
              }
            }
          }
        }
        // Barrier is outside the divergent guard above: all threads reach it.
        __syncthreads();
      }

      // Write back this tile's fragment, bounds-checked at the matrix edges.
      // Tiles partition C disjointly, so each element is written exactly once.
      if (tid < kComputeThreads) {
        for (int i = 0; i < TM; ++i) {
          const int row = blockRow + threadRow + i;
          if (row >= M) continue;
          for (int j = 0; j < TN; ++j) {
            const int col = blockCol + threadCol + j;
            if (col >= N) continue;
            const int idx = row + col * ldc;
            C[idx] = alpha * acc[i][j] + beta * C[idx];
          }
        }
      }
    }
  }
}

// Host launcher for TiledGemmKernel: C = alpha * A * B + beta * C,
// column-major, with A (MxK), B (KxN), C (MxN) all in device memory.
// Synchronous: returns only after the kernel has finished (or exits on error).
void TiledGemm(
  int M, int N, int K,
  float alpha,
  const float* A,
  const float* B,
  float beta,
  float* C) {

  dim3 block(256);
  // One block per kBlockM x kBlockN output tile (ceil division); grid.x walks
  // N, grid.y walks M.
  dim3 grid(
    (N + kBlockN - 1) / kBlockN,
    (M + kBlockM - 1) / kBlockM
  );

  TiledGemmKernel<<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
  // Launch-configuration errors are only reported via cudaGetLastError();
  // they would otherwise be silently dropped.
  CHECK_CUDA(cudaGetLastError());
  // Surfaces asynchronous execution errors from the kernel itself.
  CHECK_CUDA(cudaDeviceSynchronize());
}

// CPU reference GEMM for verification: C = alpha * A * B + beta * C.
// All matrices are column-major with leading dimension equal to their row
// count: A is M x K, B is K x N, C is M x N.
void ReferenceGemm(
  int M, int N, int K,
  float alpha,
  const float* A,
  const float* B,
  float beta,
  float* C) {

  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      const float* bCol = &B[col * K];  // column `col` of B is contiguous
      float dot = 0.0f;
      for (int k = 0; k < K; ++k) {
        dot += A[row + k * M] * bCol[k];
      }
      float& out = C[row + col * M];
      out = alpha * dot + beta * out;
    }
  }
}

// Fills `data[0..size)` with pseudo-random floats uniformly in [-1, 1],
// drawn from rand() (seeded by the caller, or rand()'s default state).
void RandomInit(float* data, int size) {
  for (int idx = 0; idx < size; ++idx) {
    const float unit = float(rand()) / RAND_MAX;  // in [0, 1]
    data[idx] = 2.0f * unit - 1.0f;
  }
}

// Compares two column-major M x N matrices elementwise. Returns true when
// every absolute difference is within `tolerance`; otherwise reports the
// first offending element (row-major scan order) to stderr and returns false.
bool Verify(const float* C1, const float* C2, int M, int N, float tolerance = 1e-3f) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      const int idx = i + j * M;
      const float diff = fabsf(C1[idx] - C2[idx]);
      if (!(diff > tolerance)) {
        continue;
      }
      std::cerr << "Mismatch at C[" << i << "," << j << "]: "
                << C1[idx] << " vs " << C2[idx]
                << " (diff=" << diff << ")" << std::endl;
      return false;
    }
  }
  return true;
}

// Driver: runs the tiled GPU GEMM, times it with CUDA events, verifies the
// result against the CPU reference, and reports TFLOPS.
// Usage: demo [M N K]   (defaults: 512 512 512)
// Returns 0 on verification PASS, 1 on FAIL or invalid arguments.
int main(int argc, char** argv) {
  int M = 512;
  int N = 512;
  int K = 512;
  float alpha = 1.0f;
  float beta = 0.0f;

  if (argc >= 4) {
    M = atoi(argv[1]);
    N = atoi(argv[2]);
    K = atoi(argv[3]);
    // atoi returns 0 on garbage; reject non-positive dimensions explicitly
    // instead of running a meaningless (or zero-sized) GEMM.
    if (M <= 0 || N <= 0 || K <= 0) {
      std::cerr << "Invalid dimensions: M, N, K must be positive integers"
                << std::endl;
      return 1;
    }
  }

  std::cout << "GEMM: M=" << M << ", N=" << N << ", K=" << K << std::endl;

  float *h_A, *h_B, *h_C_tiled, *h_C_ref;
  float *d_A, *d_B, *d_C;

  h_A = new float[M * K];
  h_B = new float[K * N];
  h_C_tiled = new float[M * N];
  h_C_ref = new float[M * N];

  RandomInit(h_A, M * K);
  RandomInit(h_B, K * N);

  CHECK_CUDA(cudaMalloc(&d_A, M * K * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_B, K * N * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_C, M * N * sizeof(float)));

  CHECK_CUDA(cudaMemcpy(d_A, h_A, M * K * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_B, h_B, K * N * sizeof(float), cudaMemcpyHostToDevice));
  // beta is 0, but zero C anyway so the beta*C term reads defined memory.
  CHECK_CUDA(cudaMemset(d_C, 0, M * N * sizeof(float)));

  cudaEvent_t start, stop;
  CHECK_CUDA(cudaEventCreate(&start));
  CHECK_CUDA(cudaEventCreate(&stop));

  CHECK_CUDA(cudaEventRecord(start));
  TiledGemm(M, N, K, alpha, d_A, d_B, beta, d_C);
  CHECK_CUDA(cudaEventRecord(stop));
  CHECK_CUDA(cudaEventSynchronize(stop));

  float milliseconds = 0;
  CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

  CHECK_CUDA(cudaMemcpy(h_C_tiled, d_C, M * N * sizeof(float), cudaMemcpyDeviceToHost));

  // Host reference for correctness checking (slow: O(M*N*K)).
  ReferenceGemm(M, N, K, alpha, h_A, h_B, beta, h_C_ref);

  bool passed = Verify(h_C_tiled, h_C_ref, M, N);

  // Compute FLOP count in double: 2*M*N*K exceeds float's 24-bit mantissa
  // for moderately large problems and would skew the reported rate.
  double tflops = (2.0 * M * N * K) / (milliseconds * 1e-3) / 1e12;
  std::cout << "Tiled GEMM: " << milliseconds << " ms" << std::endl;
  std::cout << "Performance: " << tflops << " TFLOPS" << std::endl;
  std::cout << "Result: " << (passed ? "PASSED" : "FAILED") << std::endl;

  CHECK_CUDA(cudaEventDestroy(start));
  CHECK_CUDA(cudaEventDestroy(stop));
  CHECK_CUDA(cudaFree(d_A));
  CHECK_CUDA(cudaFree(d_B));
  CHECK_CUDA(cudaFree(d_C));

  delete[] h_A;
  delete[] h_B;
  delete[] h_C_tiled;
  delete[] h_C_ref;

  return passed ? 0 : 1;
}