Commit 869b7e83 authored by Tim Dettmers
Browse files

Warp multi-specialization 240.

parent 77f15fdc
...@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, ...@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
const int half_warp_lane = threadIdx.x % 16; const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2; const int batch_size_warps = (WARPS-1)*2;
T local_A[1]; T local_A[2];
T local_B[32]; T local_B[64];
const int a_tile_offset = 16; const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16); const int b_tile_offset = (16*32 + 16);
...@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, ...@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
int ticktock = 0; int ticktock = 0;
int idx = 0 + threadIdx.x; int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch // prefetch
if(idx < K && warp_id < (WARPS-1)) if(idx < K && warp_id < (WARPS-1))
{ {
local_A[0] = A[idx]; if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32 #pragma unroll 32
for(int col = 0; col < 32; col++) for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx]; {
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
}
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
...@@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, ...@@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
__syncthreads(); __syncthreads();
if(idx < K && warp_id < (WARPS-1)) if(idx < K && warp_id < (WARPS-1))
{ {
local_A[0] = A[idx]; //local_A[0] = A[idx];
#pragma unroll 32 //#pragma unroll 32
for(int col = 0; col < 32; col++) //for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx]; // local_B[col] = B[(col_offset+col)*ldb+idx];
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
}
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
......
...@@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype): ...@@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
#print('') #print('')
#print(A) #print(A)
#print(B.t()) #print(B.t())
#A[:, :-3] = 0 #A[:, :-1] = 0
#B[:, :-3] = 0 #B[:, :-1] = 0
C1 = torch.matmul(A, B.t()) C1 = torch.matmul(A, B.t())
...@@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype): ...@@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
# print('') # print('')
# print(i, err, mag.item(), relerr.item()) # print(i, err, relerr)
# print(A.flatten()[-6:]) # print(A.flatten()[-6:])
# print(B.flatten()[-6:]) # print(B.flatten()[-6:])
# out = A.flatten()[-6:]*B.flatten()[-6:] # out = A.flatten()[-6:]*B.flatten()[-6:]
...@@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype): ...@@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
c = int(C1.numel()*0.0014*(dim/256))+1 c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
#print(c/math.sqrt(dim)) #print(c/math.sqrt(dim))
print('') print('')
print(dim, sum(errs)/len(errs)/math.sqrt(dim)) print(dim, sum(errs)/len(errs)/math.sqrt(dim))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment