Unverified Commit a73bf6de authored by Weile's avatar Weile Committed by GitHub
Browse files

reduce compile time - move se type to runtime var (#206)

parent 4f4cad7e
......@@ -3015,18 +3015,18 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
}
// Kernel for GFX execution
template <typename PACKED_FLOAT, int BLOCKSIZE, int UNROLL, int TEMPORAL_MODE, int SE_TYPE>
template <typename PACKED_FLOAT, int BLOCKSIZE, int UNROLL, int TEMPORAL_MODE>
__global__ void __launch_bounds__(BLOCKSIZE)
GpuReduceKernel(SubExecParam* params, int waveOrder, int numSubIterations)
GpuReduceKernel(SubExecParam* params, int seType, int warpSize, int waveOrder, int numSubIterations)
{
int64_t startCycle;
// For warp-level, each warp's first thread records timing; for threadblock-level, only first thread of block
bool shouldRecordTiming = (SE_TYPE == 1) ? (threadIdx.x % warpSize == 0) : (threadIdx.x == 0);
bool shouldRecordTiming = (seType == 1) ? (threadIdx.x % warpSize == 0) : (threadIdx.x == 0);
if (shouldRecordTiming) startCycle = GetTimestamp();
// SE_TYPE: 0=threadblock, 1=warp
// seType: 0=threadblock, 1=warp
int subExecIdx;
if (SE_TYPE == 0) {
if (seType == 0) {
// Threadblock-level: each threadblock is a subexecutor
subExecIdx = blockIdx.y;
} else {
......@@ -3039,7 +3039,7 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
SubExecParam& p = params[subExecIdx];
// For warp-level dispatch, inactive warps should return early
if (SE_TYPE == 1 && p.N == 0) return;
if (seType == 1 && p.N == 0) return;
// Filter by XCC
#if !defined(__NVCC__)
......@@ -3060,7 +3060,7 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
int32_t const nTeams = p.teamSize; // Number of threadblocks working together on this subarray
int32_t const teamIdx = p.teamIdx; // Index of this threadblock within the team
int32_t nWaves, waveIdx;
if (SE_TYPE == 0) {
if (seType == 0) {
// Threadblock-level: all wavefronts in block work together
nWaves = BLOCKSIZE / warpSize; // Number of wavefronts within this threadblock
waveIdx = threadIdx.x / warpSize; // Index of this wavefront within the threadblock
......@@ -3171,7 +3171,7 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
}
// Wait for all threads to finish
if (SE_TYPE == 1) {
if (seType == 1) {
// For warp-level, sync within warp only
__syncwarp();
} else {
......@@ -3188,15 +3188,11 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
}
}
#define GPU_KERNEL_SE_TYPE_DECL(BLOCKSIZE, UNROLL, DWORD, TEMPORAL) \
{GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL, 0>, \
GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL, 1>}
#define GPU_KERNEL_TEMPORAL_DECL(BLOCKSIZE, UNROLL, DWORD) \
{GPU_KERNEL_SE_TYPE_DECL(BLOCKSIZE, UNROLL, DWORD, TEMPORAL_NONE), \
GPU_KERNEL_SE_TYPE_DECL(BLOCKSIZE, UNROLL, DWORD, TEMPORAL_LOAD), \
GPU_KERNEL_SE_TYPE_DECL(BLOCKSIZE, UNROLL, DWORD, TEMPORAL_STORE),\
GPU_KERNEL_SE_TYPE_DECL(BLOCKSIZE, UNROLL, DWORD, TEMPORAL_BOTH)}
{GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL_NONE>, \
GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL_LOAD>, \
GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL_STORE>, \
GpuReduceKernel<DWORD, BLOCKSIZE, UNROLL, TEMPORAL_BOTH>}
#define GPU_KERNEL_DWORD_DECL(BLOCKSIZE, UNROLL) \
{GPU_KERNEL_TEMPORAL_DECL(BLOCKSIZE, UNROLL, float), \
......@@ -3213,9 +3209,9 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
GPU_KERNEL_DWORD_DECL(BLOCKSIZE, 7), \
GPU_KERNEL_DWORD_DECL(BLOCKSIZE, 8)}
// Table of all GPU Reduction kernel functions (templated blocksize / unroll / dword size / temporal / se_type)
typedef void (*GpuKernelFuncPtr)(SubExecParam*, int, int);
GpuKernelFuncPtr GpuKernelTable[MAX_WAVEGROUPS][MAX_UNROLL][3][4][2] =
// Table of all GPU Reduction kernel functions (templated blocksize / unroll / dword size / temporal)
typedef void (*GpuKernelFuncPtr)(SubExecParam*, int, int, int, int);
GpuKernelFuncPtr GpuKernelTable[MAX_WAVEGROUPS][MAX_UNROLL][3][4] =
{
GPU_KERNEL_UNROLL_DECL(64),
GPU_KERNEL_UNROLL_DECL(128),
......@@ -3258,17 +3254,18 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
int wordSizeIdx = cfg.gfx.wordSize == 1 ? 0 :
cfg.gfx.wordSize == 2 ? 1 :
2;
auto gpuKernel = GpuKernelTable[cfg.gfx.blockSize/64 - 1][cfg.gfx.unrollFactor - 1][wordSizeIdx][cfg.gfx.temporalMode][cfg.gfx.seType];
auto gpuKernel = GpuKernelTable[cfg.gfx.blockSize/64 - 1][cfg.gfx.unrollFactor - 1][wordSizeIdx][cfg.gfx.temporalMode];
int warpSize = GetWarpSize();
#if defined(__NVCC__)
if (startEvent != NULL)
ERR_CHECK(hipEventRecord(startEvent, stream));
gpuKernel<<<gridSize, blockSize, 0, stream>>>(rss.subExecParamGpuPtr, cfg.gfx.waveOrder, cfg.general.numSubIterations);
gpuKernel<<<gridSize, blockSize, 0, stream>>>(rss.subExecParamGpuPtr, cfg.gfx.seType, warpSize, cfg.gfx.waveOrder, cfg.general.numSubIterations);
if (stopEvent != NULL)
ERR_CHECK(hipEventRecord(stopEvent, stream));
#else
hipExtLaunchKernelGGL(gpuKernel, gridSize, blockSize, 0, stream, startEvent, stopEvent,
0, rss.subExecParamGpuPtr, cfg.gfx.waveOrder, cfg.general.numSubIterations);
0, rss.subExecParamGpuPtr, cfg.gfx.seType, warpSize, cfg.gfx.waveOrder, cfg.general.numSubIterations);
#endif
ERR_CHECK(hipStreamSynchronize(stream));
......@@ -3335,19 +3332,20 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
int wordSizeIdx = cfg.gfx.wordSize == 1 ? 0 :
cfg.gfx.wordSize == 2 ? 1 :
2;
auto gpuKernel = GpuKernelTable[cfg.gfx.blockSize/64 - 1][cfg.gfx.unrollFactor - 1][wordSizeIdx][cfg.gfx.temporalMode][cfg.gfx.seType];
auto gpuKernel = GpuKernelTable[cfg.gfx.blockSize/64 - 1][cfg.gfx.unrollFactor - 1][wordSizeIdx][cfg.gfx.temporalMode];
int warpSize = GetWarpSize();
#if defined(__NVCC__)
if (cfg.gfx.useHipEvents)
ERR_CHECK(hipEventRecord(exeInfo.startEvents[0], stream));
gpuKernel<<<gridSize, blockSize, 0 , stream>>>(exeInfo.subExecParamGpu, cfg.gfx.waveOrder, cfg.general.numSubIterations);
gpuKernel<<<gridSize, blockSize, 0 , stream>>>(exeInfo.subExecParamGpu, cfg.gfx.seType, warpSize, cfg.gfx.waveOrder, cfg.general.numSubIterations);
if (cfg.gfx.useHipEvents)
ERR_CHECK(hipEventRecord(exeInfo.stopEvents[0], stream));
#else
hipExtLaunchKernelGGL(gpuKernel, gridSize, blockSize, 0, stream,
cfg.gfx.useHipEvents ? exeInfo.startEvents[0] : NULL,
cfg.gfx.useHipEvents ? exeInfo.stopEvents[0] : NULL, 0,
exeInfo.subExecParamGpu, cfg.gfx.waveOrder, cfg.general.numSubIterations);
exeInfo.subExecParamGpu, cfg.gfx.seType, warpSize, cfg.gfx.waveOrder, cfg.general.numSubIterations);
#endif
ERR_CHECK(hipStreamSynchronize(stream));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment