// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #include "kernels.cuh" #include "common.cuh" #include #include #include #include #include #include #include #include #include #include #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); int old = *address_as_i, assumed; do { assumed = old; old = atomicCAS( reinterpret_cast(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); } while (assumed != old); return __int_as_float(old); } __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 111 return 0.25000000f*absmax*sign; // 1111 else return 0.16666667f*absmax*sign; // 1110 else if((val & 0b0001) == 1) // 110 return 0.50000000f*absmax*sign; // 1101 else return 0.33333333f*absmax*sign; // 1100 else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else return 0.00000000f*absmax*sign; // 1000 } __device__ unsigned char dQuantizeFP4(float x) { // FP4 with bias of 3 // first bit is a sign // subnormals // 0b000 = 0 // 0b001 = 0.0625 // 0b110 = 2 // 0b111 = 3 // 0b100 = 4 // 0b101 = 6 // 0b010 = 8 // 0b011 = 12 // we do a binary search // the pivots are divided by 12 (the FP4 absmax) // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 
0b1000 : 0b0000; x = fabsf(x); if(x > 0.29166667f) if( x > 0.583333f) if( x > 0.8333333f) return 0b0011+sign; else return 0b0010+sign; else if(x > 0.4166667f) return 0b101+sign; else return 0b100+sign; else if(x > 0.0859375f) if(x > 0.20833333f) return 0b0111+sign; else return 0b0110+sign; else if(x > 0.00260417f) return 0b0001+sign; else return 0b0000+sign; } __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py if((val & 0b1000) == 8) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 return 0.5626170039176941f; else return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 return 0.33791524171829224f; else return 0.24611230194568634f; else if((val & 0b0001) == 1) // 100 return 0.16093020141124725f; else return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 return 0.0f; else return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 return -0.18477343022823334f; else return -0.28444138169288635f; else if((val & 0b0010) == 2) //00 if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else return -0.5250730514526367f; else if((val & 0b0001) == 1) // 000 return -0.6961928009986877f; else return -1.0f; } __device__ unsigned char dQuantizeNF4(float x) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py if(x > 0.03979014977812767f) if(x > 0.3893125355243683f) // 1 if(x > 0.6427869200706482f) // 11 if(x > 0.8614784181118011f) // 111 return 0b1111; else return 0b1110; else if(x > 0.5016634166240692f) // 110 return 0b1101; else return 0b1100; else if(x > 0.2035212516784668f) // 10 if(x > 0.2920137718319893f) // 101 return 0b1011; else return 0b1010; else if(x > 
0.1202552504837513f) // 100 return 0b1001; else return 0b1000; else if(x > -0.33967943489551544f) // 0 if(x > -0.13791173323988914f) // 01 if(x > -0.045525018125772476f) // 011 return 0b0111; else return 0b0110; else if(x > -0.23460740596055984f) // 010 return 0b0101; else return 0b0100; else if(x > -0.6106329262256622f) // 00 if(x > -0.4599952697753906f) // 001 return 0b0011; else return 0b0010; else if(x > -0.8480964004993439f) // 000 return 0b0001; else return 0b0000; } // sign function for lion // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA template __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); } template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; float lower = -1.0f; float upper = 1.0f; float val = smem_code[pivot]; // i>>=1 = {32, 16, 8, 4, 2, 1} for(int i = 64; i > 0; i>>=1) { if(x > val) { lower_pivot = pivot; lower = val; pivot+=i; } else { upper_pivot = pivot; upper = val; pivot-=i; } val = smem_code[pivot]; } if(upper_pivot == 255) upper = smem_code[upper_pivot]; if(lower_pivot == 0) lower = smem_code[lower_pivot]; if(!STOCHASTIC) { if(x > val) { float midpoint = (upper+val)*0.5f; if(x > midpoint) { return upper_pivot; } else return pivot; } else { float midpoint = (lower+val)*0.5f; if(x < midpoint) return lower_pivot; else return pivot; } } else { if(x > val) { float dist_to_upper = fabsf(upper-x); float dist_full = upper-val; if(rand >= dist_to_upper/dist_full) return upper_pivot; else return pivot; } else { float dist_to_lower = fabsf(lower-x); float dist_full = val-lower; if(rand >= dist_to_lower/dist_full) return lower_pivot; else return pivot; } } } template __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; float lower = SIGNED ? 
-1.0f : 0.0f; float upper = 1.0f; float midpoint; float val = quadrants[1]; int local_pivot = 1; int offset = 1; // i>>=1 = {32, 16, 8, 4, 2, 1} for(int i = 64; i > 0; i>>=1) { if(x > val) { lower_pivot = pivot; lower = val; pivot+=i; //val = i == 64 ? quadrants[2] : smem_code[pivot]; local_pivot += offset; } else { upper_pivot = pivot; upper = val; pivot-=i; //val = i == 64 ? quadrants[0] : smem_code[pivot]; local_pivot -= offset; } val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; offset -= 1; } if(x > val) { midpoint = (upper+val)*0.5f; if(x > midpoint) return upper_pivot; else return pivot; } else { midpoint = (lower+val)*0.5f; if(x < midpoint) return lower_pivot; else return pivot; } } __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) { const int tid = threadIdx.x + (blockDim.x*blockIdx.x); const int numThreads = blockDim.x*gridDim.x; for(int i = tid; i < n; i+=numThreads) { int idx = (index1[i]*maxidx1) + index2[i]; atomicAdd(&histogram[idx], src[i]); } } #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 template __launch_bounds__(THREADS_ESTIMATE, 1) __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 
1.0f : (n/BLOCK_ESTIMATE)); T vals[NUM_ESTIMATE]; typedef cub::BlockRadixSort BlockRadixSort; typedef cub::BlockLoad LoadFloat; __shared__ union { typename LoadFloat::TempStorage loadf; typename BlockRadixSort::TempStorage sort; int smem_qidx[BLOCK_ESTIMATE]; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) { valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; // do not process half-blocks if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) vals[j] = max_val; __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) vals[j] = ((float)vals[j]) * reciprocal_num_blocks; __syncthreads(); // sort into striped pattern to mitigate bank conflicts // striped pattern index for thread 0 [0, 1024, 2048, 3096] // striped pattern index for thread 1 [1, 1025, 2049, 3097] BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); __syncthreads(); for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) temp_storage.smem_qidx[j] = -1; __syncthreads(); if(threadIdx.x < 256) { float q_interval = (1.0f-(2.0f*offset))/255.0f; int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); temp_storage.smem_qidx[local_idx] = threadIdx.x; } __syncthreads(); for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) { if(temp_storage.smem_qidx[i] != -1) atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); } } } __launch_bounds__(TH, 4) __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) { const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); int valid_items = (blockIdx.x+1 == gridDim.x) ? 
n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; const int base_idx = (blockIdx.x * NUM_BLOCK); float vals[NUM]; unsigned char qvals[NUM]; //const int lane_id = threadIdx.x % 2; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreChar; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; __shared__ float smem_code[256]; //__shared__ float smem_code[2][257]; if(threadIdx.x < 256) { smem_code[threadIdx.x] = code[threadIdx.x]; //smem_code[0][threadIdx.x] = code[threadIdx.x]; //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; } for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) { // number of values already processed in blocks + // number of values already processed in this block + // rand_offset % mod value valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; __syncthreads(); LoadFloat(loadf).Load(&(A[i]), vals, valid_items); #pragma unroll 4 for(int j = 0; j < NUM; j++) qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); __syncthreads(); StoreChar(storec).Store(&(out[i]), qvals, valid_items); } } template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; typedef cub::BlockStore 0) ? 
NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; __shared__ typename LoadT::TempStorage loadt; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; __shared__ typename BlockReduce::TempStorage reduce; __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; if(DATA_TYPE == General8bit) for(int i = threadIdx.x; i < 256; i+=blockDim.x) smem_code[i] = code[i]; for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_abs_max = -FLT_MAX; __syncthreads(); LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); // 1. compute local max // 2. broadcast local max // 3. normalize inputs and quantize #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); if (threadIdx.x == 0) { smem_absmax_value[0] = 1.0f / local_abs_max; absmax[i / BLOCK_SIZE] = local_abs_max; } __syncthreads(); local_abs_max = smem_absmax_value[0]; if(STOCHASTIC) { local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } unsigned char packed_4bit = 0; switch(DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { if(!STOCHASTIC) qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); else qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); } break; case FP4: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_4bit; } break; case NF4: #pragma unroll NUM_PER_TH for(int j = 0; j < 
NUM_PER_TH/2; j++) { packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_4bit; } break; } __syncthreads(); StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); } } template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { const int n_load = (gridDim.x * TILE_SIZE); int valid_items_load = 0; int valid_items_store = 0; const int base_idx = (blockIdx.x * TILE_SIZE); T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { if (DATA_TYPE > 0) { valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); valid_items_store = min(TILE_SIZE * 2, n - i * 2); } else { valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); switch (DATA_TYPE) { case General8bit: // load code through read-only cache via __ldg #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) vals[j] = __ldg(&code[qvals[j]])*local_abs_max; break; case FP4: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; } break; } __syncthreads(); StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); } } __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; __shared__ float smem_code[256]; if(threadIdx.x < 256) { smem_code[threadIdx.x] = code[threadIdx.x]; } __syncthreads(); for (int i = idx;i < n; i += numThreads) { out[i] = smem_code[A[i]]; } } template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; float s1_vals[NUM_VALS]; float s2_vals[NUM_VALS]; const float correction1 = 1.0f/(1.0f - powf(beta1, step)); const float correction2 = 1.0f/(1.0f - powf(beta2, step)); typedef cub::BlockLoad Load; typedef cub::BlockLoad LoadFloat; typedef cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) { switch(OPTIMIZER) { case ADAM: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); s1_vals[j] *= correction1; s2_vals[j] *= correction2; s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) break; } } # pragma unroll NUM_VALS-1 for(unsigned int j = 1; j < NUM_VALS; j++) s1_vals[0] += s1_vals[j]; __syncthreads(); s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); __syncwarp(); } } #define NUM_PER_THREAD 4 template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float 
beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = 0; float update_scale = 0.0f; T g_vals[NUM_PER_THREAD]; T p_vals[NUM_PER_THREAD]; float s1_vals[NUM_PER_THREAD]; float s2_vals[NUM_PER_THREAD]; // AdEMAMix has an additional state buffer, which we packed // into state1. We need thread-local storage here for these. // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. float s3_vals[NUM_PER_THREAD]; const float correction1 = 1.0f - powf(beta1, step); const float correction2 = sqrtf(1.0f - powf(beta2, step)); const float step_size = -lr*correction2/correction1; if(max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } else{ update_scale = 1.0f; } } else{ update_scale = 1.0f; } typedef cub::BlockLoad Load; typedef cub::BlockStore Store; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; typename StoreFloat::TempStorage storef; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); // Load additional state1 data for AdEMAMix // TODO: Make constexpr after updating min compiler if (OPTIMIZER == ADEMAMIX) { __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); } # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { switch(OPTIMIZER) { case ADEMAMIX: // m1 update: m1 = beta1 * m1 + (1-beta1) * g s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); // m2 update: m2 = m2 * beta3 + (1-beta3) * g s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); // nu update: nu = beta2 * nu + (1-beta2) * g^2 s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); p_vals[j] = (float)p_vals[j] - lr * ( ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( (sqrtf(s2_vals[j]) / correction2) + eps ) ); if (weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); break; case ADAM: if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } break; } } __syncthreads(); Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[i]), 
s1_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); if (OPTIMIZER == ADEMAMIX) { __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); } } } template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; float s1_vals[NUM_VALS]; typedef cub::BlockLoad Load; typedef cub::BlockLoad LoadFloat; typedef cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) { switch(OPTIMIZER) { case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; // state update else s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case LION: s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; } } # pragma unroll for(unsigned int j = 1; j < NUM_VALS; j++) s1_vals[0] += s1_vals[j]; __syncthreads(); s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); __syncwarp(); } } template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 
0 : (TH*NUM_PER_THREAD)); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = 0; float update_scale = 0.0f; if(max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } else{ update_scale = 1.0f; } } else{ update_scale = 1.0f; } T g_vals[NUM_PER_THREAD]; T p_vals[NUM_PER_THREAD]; float s1_vals[NUM_PER_THREAD]; typedef cub::BlockLoad Load; typedef cub::BlockStore Store; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; typename StoreFloat::TempStorage storef; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { g_vals[j] = gnorm_scale*((float)g_vals[j]); if(weight_decay > 0.0f) g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); } # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { switch(OPTIMIZER) { case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; else s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; case LION: p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + 
((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); break; } } } __syncthreads(); Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); } } #define NUM8BIT 16 #define NUM_THREADS 256 #define NUM_PER_BLOCK 4096 template __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, float *unorm, const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n) { const int n_full = gridDim.x * NUM_PER_BLOCK; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? 
NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_max_s2 = -FLT_MAX; float local_unorm = 0.0f; float s2_vals[NUM8BIT]; float s1_vals[NUM8BIT]; T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; unsigned char r_c2[NUM8BIT]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadUInt8; typedef cub::BlockReduce BlockReduce; __shared__ union { typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; } temp_storage; __shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles2[256]; if(threadIdx.x < 256) { smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; } __syncthreads(); for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); __syncthreads(); #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; s1_vals[j] += (1.0f-beta1)*g_val; local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); } #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; s2_vals[j] += (1.0f-beta2)*g_val*g_val; local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); } if(unorm != NULL) { #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); s1_vals[j] *= correction1; s2_vals[j] *= correction2; float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update 
local_unorm += update_val*update_val;
      }
    }
  }

  __syncthreads();
  local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
  __syncthreads();
  local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
  if(unorm != NULL)
  {
    __syncthreads();
    local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
  }

  if(threadIdx.x == 0)
  {
    atomicMax(&new_max1[0], local_max_s1);
    atomicMax(&new_max2[0], local_max_s2);
    if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
  }
}

#define NUM_PER_THREAD2 4
#define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096

// Applies one optimizer step with two 8-bit quantized states (Adam-style: state1 holds an
// exponential moving average of the gradient, state2 of the squared gradient).
//
// Layout: each block owns blockDim.x * NUM_PER_THREAD2 consecutive elements starting at
// blockIdx.x * blockDim.x * NUM_PER_THREAD2; the loop stride assumes blockDim.x == NUM_THREADS2.
// States are dequantized with the old per-tensor scales max1[0]/max2[0] and requantized with
// new_max1[0]/new_max2[0] (presumably produced by the matching preconditioning kernel — verify).
// When max_unorm > 0, sqrt(unorm[0]) is treated as the update norm and the update is scaled so
// that it does not exceed max_unorm * param_norm.
//
// Shared memory: two 256-entry float quantile tables plus one union of cub load/store storage.
// Every thread must reach every __syncthreads() below, so lanes past the end of the tensor keep
// computing on padded values; only valid_items elements are ever written back.
template __global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
                const float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                float* max1, float* max2, float* new_max1, float* new_max2,
                float weight_decay,
                const float gnorm_scale, const int n)
{
    const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[NUM_PER_THREAD2];
    float s2_vals[NUM_PER_THREAD2];

    // Adam bias corrections folded into a single step size.
    const float correction1 = 1.0f - powf(beta1, step);
    const float correction2 = sqrtf(1.0f - powf(beta2, step));
    const float step_size = -lr*correction2/correction1;

    // Reciprocals of the new per-tensor scales, used when requantizing.
    float new_max_val1 = 1.0f/new_max1[0];
    float new_max_val2 = 1.0f/new_max2[0];

    // Update-norm clipping: shrink the update when its norm exceeds max_unorm * param_norm.
    float update_scale = 1.0f;
    if(max_unorm > 0.0f)
    {
        const float update_norm = sqrtf(unorm[0]);
        if(update_norm > max_unorm*param_norm)
            update_scale = (max_unorm*param_norm)/update_norm;
    }

    unsigned char c1s[NUM_PER_THREAD2];
    unsigned char c2s[NUM_PER_THREAD2];
    T p_vals[NUM_PER_THREAD2];
    T g_vals[NUM_PER_THREAD2];
    typedef cub::BlockLoad LoadT;
    typedef cub::BlockLoad LoadChar;
    typedef cub::BlockStore StoreChar;
    typedef cub::BlockStore StoreT;

    __shared__ float smem_quantiles1[256];
    __shared__ float smem_quantiles2[256];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    // Threads 0..255 stage table 1, threads 256..511 stage table 2.
    if(threadIdx.x < 512)
    {
        if(threadIdx.x < 256)
            smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
        else
            smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
    }

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
    {
        valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;

        // Slots past the end of the tensor are padded (g = 0, state1 code 128, state2 code 0,
        // p = 0) so lanes beyond n operate on defined values.
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // No early `continue` for partially out-of-range lanes: skipping would let those threads
        // bypass the __syncthreads() calls below and would leave the final n % NUM_PER_THREAD2
        // valid elements un-updated. The bounded stores take care of the tail instead.
        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            g_val = float(g_vals[j]);
            g_val *= gnorm_scale;

            s1_vals[j] = smem_quantiles1[c1s[j]];
            s1_vals[j] = s1_vals[j]*max1[0];
            s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));

            c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);

            // make sure state1 term has still the same sign after quantization
            // (not needed for state2 term which has only positive values)
            if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
            {
                if(s1_vals[j] > 0.0f)
                    c1s[j] += 1;
                else
                    c1s[j] -= 1;
            }

            s2_vals[j] = smem_quantiles2[c2s[j]];
            s2_vals[j] = s2_vals[j]*max2[0];
            s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
            c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
        }

        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
            // Decoupled weight decay acts on the parameter alone; update_scale clips the
            // optimizer update and must not rescale the parameter itself.
            if(weight_decay > 0.0f)
                p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
        }

        // The p load above used temp_storage; a barrier is required before reusing it.
        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
        __syncthreads();
    }
}

// Preconditioning pass for kOptimizerStatic8bit1State: computes the new state1 absmax (and,
// for MOMENTUM/ADAGRAD, the squared update norm) without writing any state back.
template __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
                float *unorm,
                const float beta1, const float beta2,
                const float eps, const int
step, float* __restrict__ const quantiles1,
                float* max1, float* new_max1,
                const float weight_decay,
                const float gnorm_scale, const int n)
{
    // Recomputes the updated state1 values in registers only, to find the new per-tensor
    // absmax (published with atomicMax into new_max1[0]) and, when unorm is non-NULL, the
    // squared update norm for MOMENTUM/ADAGRAD (accumulated with atomicAdd into unorm[0]).
    // NOTE(review): both outputs are only increased here, so the caller presumably resets
    // them before launch — confirm.
    const int n_full = gridDim.x * NUM_PER_BLOCK;
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
    int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
    float g_val = 0.0f;
    float local_max_s1 = -FLT_MAX;
    float local_unorm = 0.0f;

    float s1_vals[NUM8BIT];
    T g_vals[NUM8BIT];
    unsigned char m_c1[NUM8BIT];

    typedef cub::BlockLoad LoadT;
    typedef cub::BlockLoad LoadUInt8;
    typedef cub::BlockReduce BlockReduce;

    // Loads and reductions share one shared-memory buffer; each reuse is preceded by a
    // __syncthreads().
    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadUInt8::TempStorage loadc;
        typename BlockReduce::TempStorage reduce;
    } temp_storage;

    __shared__ float smem_quantiles1[256];

    if(threadIdx.x < 256)
      smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT)
    {
        valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;

        __syncthreads();
        // Out-of-range slots are padded with g = 0 and state code 128.
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);

        #pragma unroll 16
        for(int j = 0; j < NUM8BIT; j++)
        {
            g_val = g_vals[j];
            g_val *= gnorm_scale;
            // Dequantize the old state with the old per-tensor scale.
            s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
            switch(OPTIMIZER)
            {
                case ADAGRAD:
                case MOMENTUM:
                    // NOTE(review): this path uses the raw gradient g_vals[j], not the
                    // gnorm_scale-scaled g_val used by LION/RMSPROP. kOptimizerStatic8bit1State
                    // does the same, so the two kernels agree — confirm this is intended.
                    if(step == 1)
                      s1_vals[j] = (float)g_vals[j];
                    else
                      s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
                    if(unorm != NULL)
                      local_unorm += s1_vals[j]*s1_vals[j];
                    break;
                case LION:
                    s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
                    break;
                case RMSPROP:
                    s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                    break;
            }

            local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
        }
    }

    __syncthreads();
    // NOTE(review): cub's num_valid argument counts threads, while valid_items counts elements
    // of the last tile; with padded slots this only matters for very short tails — confirm.
    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
    // The block-wide reduction result is only valid in thread 0.
    if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
    if(unorm != NULL)
    {
      __syncthreads();
      local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
      if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
    }

}

// Static (per-tensor scale) 8-bit one-state optimizer step for MOMENTUM, LION and RMSPROP
// (ADAGRAD shares the MOMENTUM path). state1 is dequantized with max1[0] and requantized with
// new_max1[0]. When max_unorm > 0, the MOMENTUM/ADAGRAD update is scaled down so that the
// update norm sqrt(unorm[0]) does not exceed max_unorm * param_norm.
template __global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                const float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1,
                float* max1, float* new_max1,
                float weight_decay,
                const float gnorm_scale, const int n)
{
    const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
    const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[NUM_PER_THREAD2];
    // Reciprocal of the new per-tensor scale, used when requantizing.
    float new_max_val1 = 1.0f/new_max1[0];
    float update_scale = 1.0f;

    if(max_unorm > 0.0f)
    {
      update_scale = max_unorm > 0.0f ?
sqrtf(unorm[0]) : 1.0f;
      if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
      else{ update_scale = 1.0f; }
    }
    else{ update_scale = 1.0f; }

    unsigned char c1s[NUM_PER_THREAD2];
    T p_vals[NUM_PER_THREAD2];
    T g_vals[NUM_PER_THREAD2];
    typedef cub::BlockLoad LoadT;
    typedef cub::BlockLoad LoadChar;
    typedef cub::BlockStore StoreChar;
    typedef cub::BlockStore StoreT;

    __shared__ float smem_quantiles1[256];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;

    if(threadIdx.x < 256)
        smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];

    __syncthreads();

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
    {
        valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;

        // Slots past the end of the tensor are padded (g = 0, state code 128, p = 0) so every
        // lane can run the math below on defined values; writes are bounded by valid_items.
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        // No early `continue` for partially out-of-range lanes: it would let those threads skip
        // the __syncthreads() calls below and would leave the final n % NUM_PER_THREAD2 valid
        // elements un-updated.
        # pragma unroll 4
        for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
        {
            g_val = float(g_vals[j]);
            g_val *= gnorm_scale;

            if(weight_decay > 0.0f) {
                switch(OPTIMIZER) {
                    case ADAGRAD:
                    case MOMENTUM:
                    case RMSPROP:
                        // L2-style decay folded into the gradient.
                        g_val += ((float)p_vals[j])*weight_decay;
                        break;
                    case LION:
                        // Decoupled decay applied directly to the parameter.
                        p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
                        break;
                }
            }

            // Dequantize the old state with the old per-tensor scale.
            s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];

            switch(OPTIMIZER){
                case ADAGRAD:
                case MOMENTUM:
                    // NOTE(review): uses the raw gradient g_vals[j] (no gnorm_scale, no weight
                    // decay), mirroring kPreconditionOptimizerStatic8bit1State's recurrence for
                    // new_max1; changing one without the other would break the requantization
                    // range.
                    if(step == 1)
                        s1_vals[j] = g_vals[j];
                    else
                        s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);

                    p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
                    break;
                case LION:
                    p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
                    s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
                    break;
                case RMSPROP:
                    s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                    p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
                    break;
            }

            c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);

            // make sure state1 term has still the same sign after quantization
            if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
            {
                if(s1_vals[j] > 0.0f)
                    c1s[j] += 1;
                else
                    c1s[j] -= 1;
            }
        }

        // The p load above used temp_storage; a barrier is required before reusing it.
        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
    }
}

// Accumulates the sum of squared gradients of g into gnorm_vec, a 100-slot history indexed by
// step % 100; on step 1 every slot is seeded with the same value. Blocks stride over the input
// in BLOCK_SIZE tiles and thread 0 of each block adds its tile's sum atomically. gnorm_vec is
// only accumulated into here, never cleared.
template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n)
{
  const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
  int valid_items = 0;

  typedef cub::BlockReduce BlockReduce;
  typedef cub::BlockLoad LoadT;

  __shared__ typename BlockReduce::TempStorage reduce;

  __shared__ typename LoadT::TempStorage loadT;
  T vals[NUM_VALS];
  float local_sum = 0.0f;

  for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
      valid_items = n - i > BLOCK_SIZE ?
BLOCK_SIZE : n - i;
      local_sum = 0.0f;

      // Barrier before reusing loadT/reduce storage from the previous iteration.
      __syncthreads();
      LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);

      #pragma unroll NUM_VALS
      for(int j = 0; j < NUM_VALS; j++)
        local_sum += ((float)vals[j])*((float)vals[j]);

      // Padded slots hold 0, so they do not change the sum.
      local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
      // The reduced sum is only valid in thread 0.
      if(threadIdx.x == 0)
      {
        if(step == 1)
        {
          // initialize with the same norm for all positions
          //#pragma unroll 10
          for(int j = 0; j < 100; j++)
            atomicAdd(&gnorm_vec[j], local_sum);
        }
        else
            atomicAdd(&gnorm_vec[step % 100], local_sum);
      }

  }
}


#define LANES 2
#define QUAD 3
// Blockwise 8-bit optimizer step with two quantized states (Adam-style) and, when
// OPTIMIZER == ADEMAMIX, a third state stored in state1[n .. 2n) whose per-block absmax is kept
// in absmax1[(n + i)/BLOCK_SIZE].
// - Each loop iteration handles BLOCK_SIZE consecutive elements; absmax1/absmax2 hold one scale
//   per BLOCK_SIZE chunk, read before and overwritten after the update.
// - Elements with a NaN/Inf gradient get their states zeroed and their parameter left unchanged.
// - skip_zeros is accepted but not read (its only use is commented out).
// - The quantile tables are staged with one entry per thread, so blockDim.x is assumed to be 256
//   (the __launch_bounds__ maximum).
template __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
    T* p,
    T* __restrict__ const g,
    unsigned char* state1,
    unsigned char* state2,
    const float beta1,
    const float beta2,
    const float beta3,
    const float alpha,
    const float eps,
    const int step,
    const float lr,
    float* __restrict__ const quantiles1,
    float* __restrict__ const quantiles2,
    float* absmax1,
    float* absmax2,
    float weight_decay,
    const float gnorm_scale,
    const bool skip_zeros,
    const int n
) {

    //const int n_full = n + (n%BLOCK_SIZE);
    const int n_full = gridDim.x * BLOCK_SIZE;
    const int base_idx = (blockIdx.x * BLOCK_SIZE);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[N_PER_TH];
    float s2_vals[N_PER_TH];
    float s3_vals[N_PER_TH];

    // 2-5%
    // Adam bias corrections; step_size folds both into the learning rate.
    const float correction1 = 1.0f - __powf(beta1, step);
    const float correction2 = sqrtf(1.0f -__powf(beta2, step));
    const float step_size = __fdividef(-lr*correction2,correction1);
    // Selects which replicated copy of the quantile tables this thread reads.
    const int lane_id = threadIdx.x % LANES;
    float new_local_abs_max1 = -FLT_MAX;
    float new_local_abs_max2 = -FLT_MAX;
    float new_local_abs_max3 = -FLT_MAX;
    float quadrants1[QUAD];
    float quadrants2[QUAD];

    unsigned char c1s[N_PER_TH];
    unsigned char c2s[N_PER_TH];
    unsigned char c3s[N_PER_TH];

    T g_vals[N_PER_TH];
    T p_vals[N_PER_TH];
    typedef cub::BlockLoad LoadT;
    typedef cub::BlockLoad LoadChar;

    typedef cub::BlockStore StoreChar;
    typedef cub::BlockStore StoreT;

    __shared__ float smem_quantiles1[LANES][257];
    __shared__ float
smem_quantiles2[LANES][257];
    typedef cub::BlockReduce BlockReduce1;
    typedef cub::BlockReduce BlockReduce2;
    typedef cub::BlockReduce BlockReduce3;
    __shared__ typename BlockReduce1::TempStorage reduce1;
    __shared__ typename BlockReduce2::TempStorage reduce2;
    // NOTE(review): reduce3 is declared with BlockReduce2's TempStorage but is used by
    // BlockReduce3 below; this is only correct while both typedefs name the same cub type.
    __shared__ typename BlockReduce2::TempStorage reduce3;
    // Broadcast slots for the block-wide maxima (only thread 0 receives the cub result).
    __shared__ float smem_exchange1[1];
    __shared__ float smem_exchange2[1];
    __shared__ float smem_exchange3[1]; // [[maybe_unused]]

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;
    // init: 0.2 -> 0.23

    // 0.23 -> 0.23
    // Stage the 256-entry quantile tables once, then replicate them LANES times; each thread
    // reads the copy chosen by lane_id (presumably to reduce shared-memory bank conflicts —
    // confirm).
    smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
    smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
    # pragma unroll
    for(unsigned int j = 1; j < LANES; j++)
    {
        smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
        smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
    }

    __syncthreads();

    // Pivot values: the last entry of each of the first QUAD of QUAD+1 equal index ranges of a
    // table; passed to quantize_2D.
    #pragma unroll
    for(int k = 0; k < QUAD; k++)
    {
        quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
        quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
    }

    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
    {
        // loads: 0.23 -> 0.85/1.44
        valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);

        // AdEMAMix has an additional state packed into state1.
        if (OPTIMIZER == ADEMAMIX) {
            __syncthreads();
            LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128);
        }

        new_local_abs_max1 = -FLT_MAX;
        new_local_abs_max2 = -FLT_MAX;
        new_local_abs_max3 = -FLT_MAX;

        // update: 2.48/1.57 -> 2.51/1.60
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
            {
                s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
                g_val = g_vals[j];
                //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps);
                //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val;
                g_val *= gnorm_scale;

                s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));

                s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
                s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));

                if (OPTIMIZER == ADEMAMIX) {
                    // The absmax for the third state is appended to absmax1
                    s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE];
                    s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val));
                }
            }
            else
            {
                // Non-finite gradient: zero this element's states.
                s1_vals[j] = 0.0f;
                s2_vals[j] = 0.0f;

                if (OPTIMIZER == ADEMAMIX) {
                    s3_vals[j] = 0.0f;
                }
            }

            new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
            new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));

            if (OPTIMIZER == ADEMAMIX) {
                new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j]));
            }
        }

        // reduce: 2.51/1.60 -> 2.67/1.69
        new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
        new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());

        if (OPTIMIZER == ADEMAMIX) {
            new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
        }

        // Only thread 0 holds the reduced maxima; broadcast them through shared memory.
        if(threadIdx.x == 0)
        {
            smem_exchange1[0] = new_local_abs_max1;
            smem_exchange2[0] = new_local_abs_max2;

            if (OPTIMIZER == ADEMAMIX) {
                smem_exchange3[0] = new_local_abs_max3;
            }
        }

        __syncthreads();

        if(threadIdx.x == 0)
        {
            // Persist the new per-block scales.
            absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
            absmax2[i/BLOCK_SIZE] = new_local_abs_max2;

            if (OPTIMIZER == ADEMAMIX) {
                absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3;
            }
        }
        else
        {
            new_local_abs_max1 = smem_exchange1[0];
            new_local_abs_max2 = smem_exchange2[0];

            if (OPTIMIZER == ADEMAMIX) {
                new_local_abs_max3 = smem_exchange3[0];
            }
        }

        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
        // reduce: 2.67/1.69 -> 2.67/1.70
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
            if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
            {
                if (OPTIMIZER == ADEMAMIX) {
                    p_vals[j] = T((float)p_vals[j] - lr * (
                        ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
                            (sqrtf(s2_vals[j]) / correction2) + eps
                        )
                    ));
                } else {
                    p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
                }

                // Decoupled weight decay, applied after the optimizer update.
                if(weight_decay > 0.0f)
                    p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
            }
        }

        //  store: 0.85/1.44 -> 2.48/1.57
        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);

        //  quantization: 2.67/1.70  -> 3.4/3.3
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
            c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));

            // make sure state1 term has still the same sign after quantization
            // (not needed for state2 term which has only positive values)
            if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
            {
                if(s1_vals[j] > 0.0f)
                    c1s[j] += 1;
                else
                    c1s[j] -= 1;
            }

            if (OPTIMIZER == ADEMAMIX) {
                c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3));

                if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) {
                    c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1;
                }
            }
        }

        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);

        if (OPTIMIZER == ADEMAMIX) {
            __syncthreads();
            StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items);
        }
    }
}


#define LANES 2
#define QUAD 3
// Blockwise 8-bit optimizer step with one quantized state for MOMENTUM, LION, RMSPROP and
// ADAGRAD. absmax1 holds one scale per BLOCK_SIZE chunk, read before and overwritten after the
// update. With skip_zeros, elements whose gradient is exactly zero are not updated.
// The quantile table is staged with one entry per thread, so blockDim.x is assumed to be 256.
template __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
                const float beta1, const float beta2,
                const float eps, const int step, const float lr,
                float* __restrict__ const quantiles1,
                float* absmax1,
                float weight_decay,
                const float gnorm_scale, const bool skip_zeros, const int n)
{

    //const int n_full = n + (n%BLOCK_SIZE);
    const int n_full = gridDim.x * BLOCK_SIZE;
    const int base_idx = (blockIdx.x * BLOCK_SIZE);
    int valid_items = 0;
    float g_val = 0.0f;
    float s1_vals[N_PER_TH];
    // 2-5%
    const int lane_id = threadIdx.x % LANES;
    float new_local_abs_max1 = -FLT_MAX;
    float quadrants1[QUAD];

    unsigned char c1s[N_PER_TH];
    T g_vals[N_PER_TH];
    T p_vals[N_PER_TH];

    typedef cub::BlockLoad LoadT;
    typedef cub::BlockLoad LoadChar;

    typedef cub::BlockStore StoreChar;
    typedef cub::BlockStore StoreT;

    __shared__ float smem_quantiles1[LANES][257];
    typedef cub::BlockReduce BlockReduce1;
    __shared__ typename BlockReduce1::TempStorage reduce1;
    __shared__ float smem_exchange1[1];

    __shared__ union {
        typename LoadT::TempStorage loadh;
        typename LoadChar::TempStorage loadc;
        typename StoreChar::TempStorage storec;
        typename StoreT::TempStorage storeh;
    } temp_storage;
    // init: 0.2 -> 0.23

    // 0.23 -> 0.23
    smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
    # pragma unroll
    for(unsigned int j = 1; j < LANES; j++)
        smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];

    __syncthreads();

    #pragma unroll
    for(int k = 0; k < QUAD; k++)
        quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];

    for (unsigned int i = base_idx; i < n_full; i +=
gridDim.x*BLOCK_SIZE)
    {
        // loads: 0.23 -> 0.85/1.44
        valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
        __syncthreads();
        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
        __syncthreads();
        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);

        new_local_abs_max1 = -FLT_MAX;

        // update: 2.48/1.57 -> 2.51/1.60
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            g_val = float(g_vals[j]);
            g_val *= gnorm_scale;

            // Always dequantize the current state. Elements skipped below (skip_zeros with a zero
            // gradient) keep this value, so they are folded into the absmax and requantized from
            // their real state instead of uninitialized / previous-tile register contents.
            s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];

            if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
            {
                if(weight_decay > 0.0f) {
                    switch(OPTIMIZER) {
                        case MOMENTUM:
                        case ADAGRAD:
                        case RMSPROP:
                            // L2-style decay folded into the gradient.
                            g_val += ((float)p_vals[j])*weight_decay;
                            break;
                        case LION:
                            // Decoupled decay applied directly to the parameter.
                            p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
                            break;
                    }
                }

                switch(OPTIMIZER)
                {
                    case MOMENTUM:
                        if(step == 1)
                            s1_vals[j] = g_val;
                        else
                            s1_vals[j] = (s1_vals[j]*beta1) + g_val;
                        break;
                    case LION:
                        // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
                        g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
                        s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
                        break;
                    case RMSPROP:
                        s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                        // Keep the effective (gnorm-scaled, weight-decayed) gradient for the
                        // parameter update below, as kOptimizerStatic8bit1State does.
                        g_vals[j] = g_val;
                        break;
                    case ADAGRAD:
                        s1_vals[j] = s1_vals[j] + (g_val*g_val);
                        // Same as RMSPROP: the update uses the effective gradient.
                        g_vals[j] = g_val;
                        break;
                }
            }

            new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
        }

        // reduce: 2.51/1.60 -> 2.67/1.69
        new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());

        // Only thread 0 holds the reduced maximum; broadcast it through shared memory.
        if(threadIdx.x == 0)
            smem_exchange1[0] = new_local_abs_max1;

        __syncthreads();

        if(threadIdx.x == 0)
            absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
        else
            new_local_abs_max1 = smem_exchange1[0];

        // reduce: 2.67/1.69 -> 2.67/1.70
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
            {
                switch(OPTIMIZER)
                {
                    case MOMENTUM:
                        p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
                        break;
                    case LION:
                        // g_vals[j] already holds lr * sign(...) from the state update.
                        p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
                        break;
                    case RMSPROP:
                        g_val = g_vals[j];
                        p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                        break;
                    case ADAGRAD:
                        g_val = g_vals[j];
                        p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                        break;
                }
            }
        }

        //  store: 0.85/1.44 -> 2.48/1.57
        __syncthreads();
        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);

        //  quantization: 2.67/1.70  -> 3.4/3.3
        # pragma unroll N_PER_TH
        for(unsigned int j = 0; j < N_PER_TH; j++)
        {
            c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));

            // make sure state1 term has still the same sign after quantization
            // (not needed for state2 term which has only positive values)
            if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
            {
                if(s1_vals[j] > 0.0f)
                    c1s[j] += 1;
                else
                    c1s[j] -= 1;
            }
        }

        __syncthreads();
        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
    }
}

// Inputs:
//  A           [rows, cols]
// Outputs:
//  rowStats    [rows]
//  out         [rows, cols]
template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

  // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
  // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
  // The version test must name __CUDACC_VER_MINOR__ (an unknown identifier evaluates to 0 in
  // #if, which silently disabled this branch) and must also accept major versions above 12.
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 2)) || BNB_FP16_AVAILABLE
  using TReduction = T;
#else
  using TReduction = float;
#endif

  using BlockReduceT = cub::BlockReduce;

  // One block per row.
  // Threads load column values in a striped arrangement.
  // e.g. t0 reads row[0], row[0+nthreads], ..
  // and t1 reads row[1], row[1+nthreads], ..
  // Each thread will determine its local absmax.
  // We then do a blockwise reduction to determine the row's absmax.
__shared__ typename BlockReduceT::TempStorage temp_storage;
  __shared__ TReduction smem_row_absmax;

  const int row_id = blockIdx.x;
  // 64-bit row offset: row_id * cols can exceed INT_MAX for large matrices.
  const size_t row_offset = (size_t)row_id * cols;
  const T* row_data = A + row_offset;

  // Threads will read the row values in a striped access pattern and find a local absmax.
  TReduction row_local_absmax = -FLT_MIN;
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    const TReduction absval = fabsf(__ldcs(&(row_data[i])));

    // For sparse decomposition, values outside of the threshold are not to be
    // included when calculating the row's absmax.
    if constexpr (SPARSE_DECOMP) {
      row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
    } else {
      row_local_absmax = fmaxf(row_local_absmax, absval);
    }
  }

  // Reduce thread-local absmax across the block. num_valid = cols restricts the reduction to
  // threads that loaded at least one value when cols < THREADS.
  const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
  if (threadIdx.x == 0) {
    // Save our block's absmax to shared memory for the quantization step.
    rowStats[row_id] = smem_row_absmax = row_absmax;
  }
  __syncthreads();

  // Quantize row-wise.
  const float scale = __fdividef(127.0f, smem_row_absmax);
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    float val = row_data[i];

    if constexpr (SPARSE_DECOMP) {
      // For sparse decomposition, we do not want to quantize the outliers.
      // Instead they're zeroed out.
      out[row_offset + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0;
    } else {
      out[row_offset + i] = __float2int_rn(val * scale);
    }
  }
}

// Computes only the per-row absmax (no quantization): rowStats[row] = max |A[row, :]|, excluding
// values with |a| >= threshold when SPARSE_DECOMP is set. One block per row (blockIdx.x).
template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
  using BlockReduceT = cub::BlockReduce;

  // One block per row.
  // Threads load column values in a striped arrangement.
  // e.g. t0 reads row[0], row[0+nthreads], ..
  // and t1 reads row[1], row[1+nthreads], ..
  // Each thread will determine its local absmax.
  // We then do a blockwise reduction to determine the row's absmax.
__shared__ typename BlockReduceT::TempStorage temp_storage;

  const int row_id = blockIdx.x;
  // 64-bit row offset: row_id * cols can exceed INT_MAX for large matrices.
  const T* __restrict__ row_data = A + ((size_t)row_id * cols);

  // Threads will read the row values in a striped access pattern and find a local absmax.
  float row_local_absmax = -FLT_MIN;
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    const float absval = fabsf(row_data[i]);

    // For sparse decomposition, values outside of the threshold are not to be
    // included when calculating the row's absmax.
    if constexpr (SPARSE_DECOMP) {
      row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
    } else {
      row_local_absmax = fmaxf(row_local_absmax, absval);
    }
  }

  // Reduce thread-local absmax across the block.
  // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
  const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
  if (threadIdx.x == 0) {
    // The reduced value is only valid in thread 0; write the row's absmax to global memory.
    rowStats[row_id] = row_absmax;
  }
}

// NOTE(review): the explicit template arguments of these instantiations are not visible in this
// copy, so the two lines of each pair read identically — confirm they differ (e.g. in
// SPARSE_DECOMP) in the repository.
template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);


#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)

// Dequantizes an int32 matmul result to fp16:
//   out[r, c] = A[r, c] * rowStats[r] * colStats[c] * MM_DEQUANT_CONST + bias[c]
// (bias may be nullptr). Each block handles THREADS * ITEMS_PER_THREAD consecutive outputs.
// NOTE(review): n is not read by this kernel, and n_out = numRows * numCols is computed in 32-bit
// int — confirm sizes stay below INT_MAX.
template __global__ void kdequant_mm_int32_fp16(
  int* __restrict__ const A,
  float *__restrict__ const rowStats,
  float *__restrict__ const colStats,
  half *out,
  half *__restrict__ const bias,
  const int numRows,
  const int numCols,
  const int n
) {
  const int n_out = numRows * numCols;

  int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
  int thread_offset = threadIdx.x * ITEMS_PER_THREAD;

  int local_values[ITEMS_PER_THREAD];
half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; typedef cub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; int row_idx, col_idx; #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; ++j) { row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); } // Each block loads THREADS * ITEMS_PER_THREAD values from A int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; ++j) { local_output[j] = __float2half( fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) ); } #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; j++) { int outIdx = block_offset + thread_offset + j; if (outIdx < n_out) { out[outIdx] = local_output[j]; } } } template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) { // 0. Load data into 32*32 shared memory tiles // 1. transpose / reorder in shared memory // 2. 
store // COL32 FORMAT: // rows*32 tiles // TURING FORMAT: // 8*32 tiles with 4*4 subtiles // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column // index increases by 32 // AMPERE FORMAT: // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values // As such we need: // at least 32*4 shared memory tiles for col32; preferably 32*32 // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 // for efficient loading of row major we need to load 128 elements and repeat this 32 items // this would imply a 32x128 shared memory tile -> 4kb // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM // // to make the shared memory work with that occupancy we might need to union the block loads/stores // each block loads TILE_COLs columns and TILE_ROW rows // after reading a tile the row counter increase by TILE_ROWS // the col counter reset after reading 
TILE_COL elements const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; const int base_idx = (base_row*cols) + base_col; // we load 128 bytes per warp with // 32 rows for transposes that fill col32 types // so that we can have contiguous stores __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; char local_data[ITEMS_PER_THREAD]; typedef cub::BlockExchange BlockExchange; // we load row after row from the base_position // Load data row by row int warps = blockDim.x/32; int warp_id = threadIdx.x/32; int warp_lane = threadIdx.x % 32; int offset = 0; int smem_row = 0; // each warp loads one row of 128 bytes for(int row = warp_id; row < TILE_ROWS; row+=warps) { int i = base_idx + (row*cols); // we load up to 128 bytes/items per load int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; // 0. Load data into 32*32 shared memory tiles if(base_row + row < rows) { #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) { int col_idx = warp_lane+(j*32); if(col_idx < valid_items) local_data[j] = A[i+col_idx]; else local_data[j] = 0; } } else { #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data[j] = 0; } if(TRANSPOSE) { #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) { int local_col = (32*j)+warp_lane; //int local_row = row; // store as 256x32 smem_data[(local_col*33) + row] = local_data[j]; } } else { // treat smem as 32x256, that is 32 rows and 256 columns #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; } smem_row += warps; // 1. 
transpose / reorder in shared memory if(smem_row % 32 == 0) { smem_row = 0; __syncthreads(); for(int subrow = warp_id; subrow < 32; subrow+=warps) { for(int j = 0; j < ITEMS_PER_THREAD; j++) { switch(FORMAT) { case COL32: if(TRANSPOSE) { // data lies in shared memory in the following way: // row0 [col0 col1 ... col31] // row1 [col0 col1 ... col31] // ... // // As such we read consecutive entries with 256 threads (8rows x 32 columns) // as j increase, the row increase by a factor of 8 // We load 8 rows per subrow loop, and subrow increase by 8 per loop // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) //const int local_row = warp_id; // each warp_id is one row //const int block_row = base_col; // block offset for row //const int local_col = warp_lane //const int global_col = base_row; // block offset for col if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) { // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; // each 32 columns we have new tile // each tile has size outRows*32 and base_row is done in increments of 32 offset = base_row*outRows; out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; } } else { if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) { offset = (base_col/32)*(32*rows); char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; } } break; case COL_TURING: // TURING FORMAT: // 8*32 tiles with 4*4 subtiles // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) // the subsequent 4*4 subtiles are for all odd rows if some rows 
columns are empty the values are zero // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column // index increases by 32 // // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] if(TRANSPOSE) { const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) //const int local_row = warp_id; // each warp_id is one row //const int block_row = base_col; // block offset for row //const int local_col = warp_lane //const int global_col = base_row; // block offset for col if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) { // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; // each 32 columns we have new tile // each tile has size 8*32 = 256 elements offset // for each row offset of 8 we increaes the tile first // after all rows are exhausted, we increase the col int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows // we increase by row_tile_column every 32 columns // base_row increase in increments of 32 //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements //int col_offset = (base_row/32)*row_tile_column; // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 // 256*outRows/8*base_row/32 = outRows*base_row int col_offset = outRows*base_row; offset = row_offset+col_offset; // since we process even number of rows with each j (8) and with each subrow (8j) we can determine // odd or even rows with the warp_id (each warp processes one row) // the col is warp_lane (max 32 
columns per row) and the row warp_id if(warp_id % 2 == 1) // odd offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); else // even offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); out[offset] = data; } } else { if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) { char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; // set offset designates the tile offset among the 8*32 tiles // we first increase rows and then columns. Since we load 128 columns at once // we increase the offset by outRows*32 every 32 columns // additionally, we increase the offset by 8*32=256 every 8 rows offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd // each of these has 32 values in total for 32*4 = 128 as offset if odd // every set of 4 columns increases the total offset by 16 // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 // this happens every 8 rows anew (subrow % 8) // one writes 4 columns at once that is (col % 4) for the particular index in the subtile int subcol = warp_lane; // add local offset (4x4 sub-tile) if(subrow % 2 == 1) // odd offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); else // even offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); out[offset] = data; } } break; case COL_AMPERE: // AMPERE FORMAT: // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] if(TRANSPOSE) { const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) //const int local_row = warp_id; // each warp_id is one row //const int block_row = base_col; // block offset for row //const int local_col = warp_lane //const int global_col = base_row; // block offset for col if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) { // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; // each 32 columns we have new tile // each tile has size 32*32 = 1024 elements offset // for each row offset of 32 we increaes the tile first // after all rows are exhausted, we increase the col int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows // we increase by row_tile_column every 32 columns // base_row increase in increments of 32 //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements //int col_offset = (base_row/32)*row_tile_column; // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 // 1024*outRows/32*base_row/32 = outRows*base_row int col_offset = outRows*base_row; offset = row_offset+col_offset; // same as in the non-transpose case (see below) // the difference is that now rows = cols // in this case warp_id = subrow // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row // every 2 rows, the offset increases by two [0, 1, 8, 9...] // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
                int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);

                // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
                out[offset + (ampere_row*32) + warp_lane] = data;
              }
            }
            else
            {
              if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
              {
                char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];

                // set offset designates the tile offset among the 32*32 tiles
                // we first increase rows and then columns. Since we load 128 columns at once
                // we increase the offset by outRows*32 every 32 columns
                // additionally, we increase the offset by 32*32=1024 every 32 rows
                offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)

                // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
                // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
                // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
                // every 2 rows, the offset increases by two [0, 1, 8, 9...]
                // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
                int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);

                // global offset + row with 32 cols each + 32 cols per j + col_idx
                out[offset + (local_row*32) + warp_lane] = data;
              }
            }
          break;
        }
      }
    }
  }
}

#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256

// Sparse (COO) x dense product accumulated in place into `out`.
//
// One thread block handles one sparse row of A: blockIdx.x selects `count` nonzeros
// (values/colidx starting at `offset`) that all belong to output row rowidx[offset].
// Each thread covers SPMM_ITEMS consecutive columns of B per outer iteration, so one
// iteration of the outer loop covers blockDim.x*SPMM_ITEMS columns.
// The result is ADDED to the existing contents of `out` (read-modify-write).
//
// When BITS == 8 and dequant_stats != NULL, each product is additionally scaled by the
// per-column statistic dequant_stats[col] * DENORM (DENORM = 1/127).
//
// Preconditions (not checked here):
//  - count = max_count[blockIdx.x] <= MAX_SPARSE_COUNT (size of local_valA/local_colidxA).
//  - if dequant_stats != NULL: blockDim.x*SPMM_ITEMS <= SMEM_SIZE, because the shared
//    staging buffer is read at indices up to blockDim.x*SPMM_ITEMS-1.
//  - the vectorized paths access num_items contiguous elements at (index)/num_items; this
//    presumably requires colsB to be a multiple of num_items -- TODO confirm with caller.
//
// NOTE(review): template argument lists (template header, reinterpret_cast targets) are
// missing in this copy of the file; confirm against the repository version.
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
  // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block
  //    If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
  //    elements they finish faster "filling up" the gaps left by larger blocks

  // without tensor cores
  // 1. use rowidx_length to find what to load (as many blocks as there are rows)
  // 2. Load A into registers
  // 3. each warp loads all required rows of B but each warp is offset by k
  // 4. Do mma operations that accumulate into registers
  // 5. Each warp stores its output row into matrix C

  const int count = max_count[blockIdx.x];
  const int local_max_idx = max_idx[blockIdx.x];
  const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
  const int local_row_idx = rowidx[offset];

  const int warp_id = threadIdx.x / 32;
  const int warp_idx = threadIdx.x % 32;
  const int warp_offset = (warp_id*32)*SPMM_ITEMS;
  const int num_items = BITS == 8 ? 8 : 8;
  // Invariant: idx_col_B == warp_offset + local_idx_col_B_offset, where
  // local_idx_col_B_offset is the first column of the block's current window.
  int idx_col_B = warp_offset;
  int local_idx_col_B_offset = 0;

  half local_valA[MAX_SPARSE_COUNT];
  int local_colidxA[MAX_SPARSE_COUNT];
  half local_valC[SPMM_ITEMS];
  T local_valsB[num_items];
  half local_valOut[num_items];
  // 128 byte loads per warp == 4 bytes per thread

  // 2. Load A into registers (entries past `count` are zero-padded)
  for(int j = 0; j < MAX_SPARSE_COUNT; j++)
  {
    local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
    local_colidxA[j] = j < count ? colidx[offset+j] : 0;
  }

  // Per-column dequantization statistics for the block's current column window;
  // entry i holds dequant_stats[local_idx_col_B_offset + i].
  __shared__ half smem_dequant_stats[SMEM_SIZE];

  // The loop condition is block-uniform so that every warp runs the same number of
  // iterations and reaches the __syncthreads() calls below together. Warps whose own
  // columns (idx_col_B...) lie past colsB do no work: all their loads and stores are
  // guarded by the colsB checks.
  while(local_idx_col_B_offset < colsB)
  {
    if(dequant_stats != NULL)
    {
      // The reads below use smem_dequant_stats[idx + k - local_idx_col_B_offset] for column
      // idx + k, so entry i must hold column local_idx_col_B_offset + i.
      for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
        if((local_idx_col_B_offset+i) < colsB)
          smem_dequant_stats[i] = dequant_stats[local_idx_col_B_offset+i];
      __syncthreads();
    }

    #pragma unroll SPMM_ITEMS
    for(int j = 0; j < SPMM_ITEMS; j++)
      local_valC[j] = 0.0f;

    #pragma unroll
    for(int i = 0; i < count; i++)
    {
      // 3. each warp loads all required rows of B but each warp is offset by k
      int row_offset = colsB*local_colidxA[i];

      #pragma unroll SPMM_ITEMS
      for(int j = 0; j < SPMM_ITEMS; j+=num_items)
      {
        // 4. Multiply the tile -> accumulate outputs in registers
        int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
        if(idx >= colsB){ break; }
        if((idx+num_items < colsB))
        {
          // vectorized load of num_items consecutive B values
          if(BITS == 8)
            reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items];
          else
            reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items];
        }
        else
        {
          // scalar tail: zero-pad past colsB
          #pragma unroll num_items
          for(int k = 0; k < num_items; k++)
            if(idx+k < colsB)
              local_valsB[k] = B[row_offset+idx+k];
            else
              local_valsB[k] = 0.0f;
        }

        #pragma unroll num_items
        for(int k = 0; k < num_items; k++)
        {
          if(BITS == 8 && dequant_stats != NULL)
          {
            // int8 B: scale by the column statistic staged in shared memory
            float valB = local_valsB[k];
            float valA = local_valA[i];
            if(valB != 0.0f && valA != 0.0f)
              local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA;
          }
          else
            local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
        }
      }
    }

    int idx_row_C = (colsB*local_row_idx);

    #pragma unroll SPMM_ITEMS
    for(int j = 0; j < SPMM_ITEMS; j+=num_items)
    {
      int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
      int idx_val = idx_col_C + idx_row_C;

      if(idx_col_C +num_items < colsB)
      {
        // load outputs to do inplace addition
        reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items];

        // Vector element j/num_items of local_valC (stored below) is
        // local_valC[j .. j+num_items-1], so accumulate into exactly those entries.
        #pragma unroll num_items
        for(int k = 0; k < num_items; k++)
          local_valC[j + k] = (float)local_valC[j + k] + (float)local_valOut[k];

        reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items];
      }
      else
      {
        #pragma unroll num_items
        for(int k = 0; k < num_items; k++)
          if(idx_col_C + k < colsB)
            out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
      }
    }

    idx_col_B += blockDim.x*SPMM_ITEMS;
    local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;

    // Do not let the next iteration's staging writes overwrite statistics that
    // other warps are still reading.
    if(dequant_stats != NULL)
      __syncthreads();
  }
}

// Gathers the columns listed in idx[] from a char matrix A stored in a tiled layout
// (COL_TURING or COL_AMPERE, selected by FORMAT) into `out`, laid out row-major as
// rowsA x idx_size. One block per extracted column (blockIdx.x indexes idx[]); threads
// stride over the rows. Any other FORMAT writes nothing. colsA, tiledRowsA and
// tiledColsA are not used.
template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
  int local_colidx = idx[blockIdx.x];

  if(FORMAT==COL_TURING)
  {
    // TURING FORMAT:
    // 8*32 tiles with 4*4 subtiles
    // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
    // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
    // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
    // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
    // index increases by 32
    // columns are grouped in increments of 4, meaning that one has the following rows and columns
    // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
    // cols: [0 1 2 3, 0 1 2 3, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]

    // each thread reads 1 element = 1 row
    for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
    {
      int offset_per_col_tile = ((rowsA+7)/8)*32*8;
      int tile_offset_rows = (row/8)*32*8;
      int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
      int offset = 0;
      int subtile_col_idx = local_colidx%32;
      int subtile_row_idx = row % 8;
      if(row % 2 == 1)
        // odd rows follow the 128 even-row entries of the tile
        offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
      else
        // even
        offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);

      offset += tile_offset_rows + tile_offset_cols;

      char val = A[offset];

      int out_idx = (row*idx_size) + blockIdx.x;
      out[out_idx] = val;
    }
  }
  else if(FORMAT == COL_AMPERE)
  {
    for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
    {
      // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
      // within each tile.
      int offset_per_col_tile = ((rowsA+31)/32)*32*32;
      int tile_offset_rows = (row/32)*32*32;
      int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
      int subtile_col_idx = local_colidx%32;
      int subtile_row_idx = row % 32;
      // this magic is taken from the cublasLt doc (search for COL32)
      int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
      offset += tile_offset_cols + tile_offset_rows;

      char val = A[offset];
      int out_idx = (row*idx_size) + blockIdx.x;
      out[out_idx] = val;
    }
  }
}

#define WARPS 3

// Experimental tensor-core kernel computing 32 outputs per block (out[blockIdx.x*32 + lane]).
// Warps 0..WARPS-2 are producers: each K-slice of A and of the 32 B rows starting at
// col_offset (B indexed as B[(col_offset+col)*ldb + k]) is copied into double-buffered
// shared tiles selected by `ticktock`. Warp WARPS-1 is the consumer: it accumulates the
// tiles with wmma and writes row 0 of the accumulator. Requires __CUDA_ARCH__ >= 750;
// otherwise the body is empty.
// NOTE(review): only `idx < K` is checked before reading A[idx + n*val_per_iter] and
// B[...+idx + n*val_per_iter] (n up to 3), and col_offset+col is not checked against N;
// these reads may run past the inputs -- confirm caller guarantees.
template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
#if __CUDA_ARCH__ >= 750
  using namespace nvcuda;
  int col_offset = blockIdx.x *32;
  const int warp_id = threadIdx.x / 32;
  const int half_warp_id = threadIdx.x / 16;
  const int half_warp_lane = threadIdx.x % 16;
  const int batch_size_warps = (WARPS-1)*2;
  const int val_per_iter = blockDim.x-32;

  T local_A[4];
  T local_B[128];

  const int a_tile_offset = 16;
  const int b_tile_offset = (16*32 + 16);

  // Sized from the indexing below: tile t starts at t*a_tile_offset / t*b_tile_offset with
  // t < 2*batch_size_warps, and each load_matrix_sync (ldm 16) spans 8*16 elements of A and
  // 16*32 elements of B (presumably 8x16 row-major A / 16x32 col-major B fragments; the
  // fragment template arguments are not visible in this copy).
  __shared__ T smem_A[8*16 + (a_tile_offset*(2*batch_size_warps-1))];
  __shared__ T smem_B[16*32 + (b_tile_offset*(2*batch_size_warps-1))];
  // Accumulator staging: the 32-wide row-major store (ldm 32) needs 8*32 elements,
  // more than smem_A provides.
  __shared__ T smem_C[8*32];

  wmma::fragment a_frag;
  wmma::fragment b_frag;
  wmma::fragment c_frag;
  wmma::fill_fragment(c_frag, 0.0f);

  int ticktock = 0;
  int idx = 0 + threadIdx.x;
  int loaded_values = 0;
  // prefetch
  if(idx < K && warp_id < (WARPS-1))
  {
    if(loaded_values == 0)
    {
      local_A[0] = A[idx];
      local_A[1] = A[idx+(1*val_per_iter)];
      local_A[2] = A[idx+(2*val_per_iter)];
      local_A[3] = A[idx+(3*val_per_iter)];

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
      {
        local_B[col] = B[(col_offset+col)*ldb+idx];
        local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
        local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
        local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
      }
      loaded_values = 3;
    }
    else
    {
      if(loaded_values == 3)
      {
        local_A[0] = local_A[1];
        #pragma unroll 32
        for(int col = 0; col < 32; col++)
          local_B[col] = local_B[col+(32)];
      }
      else if(loaded_values == 2)
      {
        local_A[0] = local_A[2];
        #pragma unroll 32
        for(int col = 0; col < 32; col++)
          local_B[col] = local_B[col+(64)];
      }
      else
      {
        local_A[0] = local_A[3];
        #pragma unroll 32
        for(int col = 0; col < 32; col++)
          local_B[col] = local_B[col+(96)];
      }
      loaded_values--;
    }

    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
  }
  else if(warp_id < (WARPS-1))
  {
    local_A[0] = T(0.0);
    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      local_B[col] = 0.0f;

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
  }
  ticktock = ticktock == 0 ? 1 : 0;

  for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
  {
    idx = base_idx + threadIdx.x;

    // consumer must be done with a buffer before producers refill it
    __syncthreads();
    if(idx < K && warp_id < (WARPS-1))
    {
      if(loaded_values == 0)
      {
        local_A[0] = A[idx];
        local_A[1] = A[idx+(1*val_per_iter)];
        local_A[2] = A[idx+(2*val_per_iter)];
        local_A[3] = A[idx+(3*val_per_iter)];

        #pragma unroll 32
        for(int col = 0; col < 32; col++)
        {
          local_B[col] = B[(col_offset+col)*ldb+idx];
          local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
          local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
          local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
        }
        loaded_values = 3;
      }
      else
      {
        if(loaded_values == 3)
        {
          local_A[0] = local_A[1];
          #pragma unroll 32
          for(int col = 0; col < 32; col++)
            local_B[col] = local_B[col+(32)];
        }
        else if(loaded_values == 2)
        {
          local_A[0] = local_A[2];
          #pragma unroll 32
          for(int col = 0; col < 32; col++)
            local_B[col] = local_B[col+(64)];
        }
        else
        {
          local_A[0] = local_A[3];
          #pragma unroll 32
          for(int col = 0; col < 32; col++)
            local_B[col] = local_B[col+(96)];
        }
        loaded_values--;
      }

      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
    }
    else if(warp_id < (WARPS-1))
    {
      local_A[0] = T(0.0);
      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        local_B[col] = 0.0f;

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
    }
    ticktock = ticktock == 0 ? 1 : 0;

    // consumer warp: accumulate the buffer filled in the previous iteration
    if(warp_id == (WARPS-1))
      for(int k = 0; k < batch_size_warps; k++)
      {
        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); //  111 mu
        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
      }
  }

  __syncthreads();
  if(warp_id != (WARPS-1)){ return; }
  // only warp_id == (WARPS-1) from here
  int warp_lane = threadIdx.x % 32;

  // drain the last filled buffer
  ticktock = ticktock == 0 ? 1 : 0;
  for(int k = 0; k < batch_size_warps; k++)
  {
    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); //  111 mu
    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }

  // 129 mu
  if(warp_id == (WARPS-1))
    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

  // row 0 of the accumulator holds the 32 outputs of this block
  if(col_offset + warp_lane < M)
    out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}

// Debug helper: printf every nonzero entry of A[0..num_values) as "<strval> <index> <value>".
template __device__ void printnonzero(T *A, int num_values, const char * strval)
{
  for(int i = 0; i < num_values; i++)
    if((float)A[i] != 0.0f)
      printf("%s %i %f\n", strval, i, (float)A[i]);
}

// Experimental 4-bit variant of gemm_device: B holds packed 4-bit codes that are expanded
// through the 16-entry nf4_data table. Same producer/consumer structure as gemm_device
// (warps 0..WARPS-2 fill double-buffered shared tiles, warp WARPS-1 runs wmma).
// Requires __CUDA_ARCH__ >= 750; otherwise the body is empty.
// NOTE(review): this kernel looks unfinished -- confirm before relying on its results:
//  - on the loaded_values == 0 path the freshly loaded local_B_4bit is not dequantized,
//    so smem_B receives local_B from a previous step (uninitialized in the prefetch);
//  - the dequantized values are scaled by T(absidx) (the block index) while the loaded
//    local_absmax is unused;
//  - the prefetch `else` branch is unreachable (loaded_values is 0 there) and indexes
//    quant_map[160*...] beyond its 16 entries.
template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
  //// element-wise kernel
  //// 1. Load batch x k into registers
  //// 2. Load k x k into registers
  //// 3. dequantize and store in second pair of k x k
  //// 4. matmul
  //// 5. sum with cub
  //// 6. store outputs
  //// TC kernel
  //// use k warps per thread block
  //// 1. threadblock use read-only cache to read in register tile for A into shared memory
  //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
  //// 3. each warp reads a segment of values 16x32 from B
  //// 4. do dequantization from register of B into second pair of registers
  //// 5. store (4) into fragment
  //// 6. matmul aggregate into fragment C
  //// 7. aggregate files of C into shared memory block C
  //// 8. sum (7)
  //// 9. write outputs to matmul output matrix
#if __CUDA_ARCH__ >= 750
  using namespace nvcuda;
  int col_offset = blockIdx.x *32;
  const int warp_id = threadIdx.x / 32;
  const int warp_idx = threadIdx.x % 32;
  const int half_warp_id = threadIdx.x / 16;
  const int half_warp_lane = threadIdx.x % 16;
  const int batch_size_warps = (WARPS-1)*2;

  T quant_map[16];

  #pragma unroll 16
  for(int i = 0; i < 16; i++)
    quant_map[i] = nf4_data[i];
  //__shared__ T quant_map[16*160];

  T local_A[2];
  T local_B[64];
  unsigned char local_B_4bit[32];

  const int a_tile_offset = 16;
  const int b_tile_offset = (16*32 + 16);

  // Sized from the indexing below (see the matching declaration in gemm_device): tile t
  // starts at t*tile_offset with t < 2*batch_size_warps and each load spans 8*16 (A) or
  // 16*32 (B) elements.
  __shared__ T smem_A[8*16 + (a_tile_offset*(2*batch_size_warps-1))];
  __shared__ T smem_B[16*32 + (b_tile_offset*(2*batch_size_warps-1))];
  __shared__ T smem_C[8*32];

  wmma::fragment a_frag;
  wmma::fragment b_frag;
  wmma::fragment c_frag;
  wmma::fill_fragment(c_frag, 0.0f);

  for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
    smem_C[i] = 0.0f;

  __syncthreads();

  int ticktock = 0;
  int idx = 0 + threadIdx.x;
  int loaded_values = 0;
  // prefetch
  if(idx < K && warp_id < (WARPS-1))
  {
    if(loaded_values == 0)
    {
      local_A[0] = A[idx];
      local_A[1] = A[idx+blockDim.x-32];

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        local_B_4bit[col] = B[(col_offset+col)*ldb+idx];

      loaded_values = 1;
    }
    else
    {
      local_A[0] = local_A[1];
      loaded_values--;

      #pragma unroll 64
      for(int col = 0; col < 64; col+=2)
      {
        //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
        //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
        //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
        //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
        //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
        //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);

        //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
        //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
        local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
        local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
      }
    }

    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
  }
  else if(warp_id < (WARPS-1))
  {
    local_A[0] = T(0.0);
    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      local_B[col] = 0.0f;

    #pragma unroll 32
    for(int col = 0; col < 32; col++)
      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
  }
  ticktock = ticktock == 0 ? 1 : 0;
  //if(threadIdx.x == 0)
    //printf("aa %i %i\n", idx, loaded_values);

  for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
  {
    idx = base_idx + threadIdx.x;
    //if(threadIdx.x == 0)
      //printf("%i %i\n", idx, loaded_values);

    // As in gemm_device: the consumer warp must be done reading a buffer before the
    // producers refill it (loop bounds are block-uniform, so every thread arrives).
    __syncthreads();
    if(idx < K && warp_id < (WARPS-1))
    {
      if(loaded_values == 0)
      {
        local_A[0] = A[idx];
        local_A[1] = A[idx+blockDim.x-32];

        // one packed byte per column; local_B_4bit holds exactly 32 entries
        #pragma unroll 32
        for(int col = 0; col < 32; col++)
          local_B_4bit[col] = B[(col_offset+col)*ldb+idx];

        loaded_values = 1;
      }
      else
      {
        local_A[0] = local_A[1];
        loaded_values--;

        int absidx = (idx + col_offset)/blocksize;
        half local_absmax = __ldg(&(absmax[absidx]));

        #pragma unroll 64
        for(int col = 0; col < 64; col+=2)
        {
          //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
          //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
          //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
          //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);

          //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
          //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
          local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
          local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
        }
        //printnonzero(local_B, 128, "");
      }

      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
    }
    else if(warp_id < (WARPS-1))
    {
      local_A[0] = T(0.0);
      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        local_B[col] = 0.0f;

      #pragma unroll 32
      for(int col = 0; col < 32; col++)
        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
    }
    ticktock = ticktock == 0 ? 1 : 0;

    if(warp_id == (WARPS-1))
      for(int k = 0; k < batch_size_warps; k++)
      {
        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); //  111 mu
        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
      }
  }

  __syncthreads();
  //if(threadIdx.x == 0)
  //{
  //  printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
  //  printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
  //}
  if(warp_id != (WARPS-1)){ return; }
  // only warp_id == (WARPS-1) from here
  int warp_lane = threadIdx.x % 32;

  ticktock = ticktock == 0 ? 1 : 0;
  for(int k = 0; k < batch_size_warps; k++)
  {
    //if(warp_lane == 0)
      //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); //  111 mu
    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }

  // 129 mu
  if(warp_id == (WARPS-1))
    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

  //printnonzero(smem_C, 32, "");

  if(col_offset + warp_lane < M)
    out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}

#define num_values_4bit 32

// 4-bit dequantize + matrix-vector product: out[row_B] = sum_k A[k] * dequant(B[row_B, k]).
//
// One warp per output row: row_B = (THREADS/32)*blockIdx.x + warp_idx. Each lane processes
// num_values_4bit (32) consecutive k values per step, read as num_values_4bit/2 packed bytes
// (high nibble first). A code c dequantizes to datatype[c] * absmax[block], where block is
// ((2*offset_B)+inner_idx)/blocksize. Partial sums are accumulated in float, reduced with
// cub::WarpReduce, and lane 0 writes the result.
//
// Preconditions (assumed, not checked): blocksize is a power of two (it is applied as a
// shift); one absmax value is used for each lane's 32-value chunk, so blocksize presumably
// must be a multiple of num_values_4bit; ldb is presumably the packed row length in bytes
// (K/2) -- TODO confirm with caller.
template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{
  // per threadblock:
  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
  // 4 warps -> 4 loads per iter
  // 1x32 * 32x4 -> 1x4 outputs per thread block
  typedef cub::WarpReduce WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

  const int warp_idx = threadIdx.x / 32;
  const int warp_lane = threadIdx.x % 32;
  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
  const int offset_B = ldb*row_B;
  const int num_values_8bit = num_values_4bit/2;
  float local_C = 0.0f;

  unsigned char local_B_4bit[num_values_8bit];
  T local_B[num_values_4bit/4];
  T local_A[num_values_4bit/4];
  __shared__ T quant_map[16];
  T local_absmax = T(0.0f);

  // 16-entry code -> value table shared by the block
  if (threadIdx.x < 16)
    quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
  __syncthreads();

  // A: [1, K]
  // B: [N, K]
  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
  {
    const int inner_idx_halved = inner_idx/2;

    // Since blocksize will always be a power-of-2, we avoid more expensive
    // division by the blocksize and instead use a shift operation.
    // This is equivalent to ((2*offset_B)+inner_idx)/blocksize.
    const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize));

    if(row_B < M)
    {
      // Only existing rows have absmax entries. Warps past the last row (row_B >= M,
      // possible in the last block) keep local_absmax at zero and never write output.
      local_absmax = __ldg(&(absmax[absidx]));

      if((inner_idx_halved + num_values_8bit) < (K/2))
      {
        // this is the most important for performance considerations
        reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
      }
      else
      {
        #pragma unroll
        for(int j = 0; j < (num_values_8bit); j++)
          if((inner_idx_halved) + j < (K/2))
            local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
          else
            local_B_4bit[j] = 0b01110111;
      }
    }
    else
    {
      #pragma unroll
      for(int j = 0; j < (num_values_8bit); j++)
        local_B_4bit[j] = 0b01110111;
    }

    for(int i = 0; i < 4; i++)
    {
      #pragma unroll
      for(int k = 0; k < num_values_8bit/4; k++)
      {
#if BNB_BF16_AVAILABLE
        local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
        local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
#else
        // bf16 multiplication not supported
        local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax);
        local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax);
#endif
      }

      if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K)
      {
        // this is also relatively important for performance
        if(BITS==16)
        {
          reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i];
        }
        else
        {
          reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0];
          reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1];
        }
      }
      else
      {
        // scalar tail: zero-pad past K
        #pragma unroll
        for(int k = 0; k < num_values_4bit/4; k++)
          if(inner_idx + (i*num_values_4bit/4) + k < K)
            local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)];
          else
            local_A[k] = T(0.0f);
      }

      // accumulate in float; small performance hit for Ampere, but lower error for outputs
      #pragma unroll
      for(int k = 0; k < num_values_4bit/4; k++)
      {
#if BNB_BF16_AVAILABLE
        local_C += (float)(local_A[k]*local_B[k]);
#else
        // bf16 multiplication not supported
        local_C += ((float)local_A[k]*(float)local_B[k]);
#endif
      }
    }
  }

  local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);

  if(row_B < M && warp_lane == 0)
    out[row_B] = T(local_C);
}

// Elementwise grid-stride kernel over A[0..n):
//   FILL:   A[i] = value
//   ARANGE: A[i] = i
//   _MUL:   A[i] = A[i] * B[i]
// The start index and stride are computed in 64-bit so that blockDim.x*blockIdx.x and
// blockDim.x*gridDim.x cannot wrap around in 32-bit unsigned arithmetic for large grids.
template __global__ void kfunc(T *A, T *B, T value, long n)
{
  const long stride = (long)blockDim.x*gridDim.x;
  for(long i = ((long)blockDim.x*blockIdx.x) + threadIdx.x; i < n; i += stride)
  {
    switch(FUNC)
    {
      case FILL:
        A[i] = (T)value;
        break;
      case ARANGE:
        A[i] = (T)i;
        break;
      case _MUL:
        A[i] = A[i]*B[i];
        break;
    }
  }
}

//==============================================================
//                   TEMPLATE DEFINITIONS
//==============================================================

template __global__ void kfunc(float *A, float *B, float value, long n);
template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n);
template __global__ void kfunc(float *A, float *B, float value, long n);
template __global__ void kfunc(float *A, float *B, float value, long n);

// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template
__global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);

// ---------------------------------------------------------------------------
// Explicit template instantiations.
// Each "template <decl>;" line below is an explicit instantiation definition:
// it forces the compiler to emit that specialization in this translation unit
// (presumably so host-side launch code compiled elsewhere can link against it).
//
// NOTE(review): many lines below name a function template with no <...>
// argument list, e.g. the half/float kgemm_4bit_inference_naive lines, whose
// __nv_bfloat16 sibling is spelled kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>.
// If these templates have non-type parameters (as that spelling suggests), those
// cannot be deduced from the function parameters. So these look like argument
// lists lost in transit rather than intentional deduction.
// TODO: confirm against version control before relying on this list.
// ---------------------------------------------------------------------------

// gemm_device: half-precision A/B/out instantiations (list continues from the line above).
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);

// kgemm_4bit_inference: half A/out, B as unsigned char* with per-block absmax scales.
// Four instantiations; their distinguishing template arguments are not visible here.
template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

// kgemm_4bit_inference_naive: one instantiation each for half, __nv_bfloat16 and float.
// It also takes a `datatype` lookup table pointer.
template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

// kExtractOutliers: two instantiations with identical signatures.
// The template argument that distinguishes them is not visible here.
template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

// kspmm_coo_very_sparse_naive: COO sparse x dense matmul.
// Three instantiations with a half B operand, then three with a signed char B operand.
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

// kTransformRowToFormat: every combination of the 0/1 flag (5th argument)
// with the COL32, COL_TURING and COL_AMPERE target formats.
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

// kdequant_mm_int32_fp16: int32 matmul result -> half output, using row/col stats and an optional bias pointer.
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);

// dQuantize device helper, instantiated for template argument 0 and 1.
// Presumably this toggles use of the `rand` argument (stochastic rounding) -- TODO confirm.
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);

// kEstimateQuantiles for float and half inputs (element type deducible from A).
template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);

// 32-bit, one-state optimizer preconditioning kernels: one instantiation per (optimizer, gtype) pair.
// NOTE(review): `oname` is not referenced in the macro body as shown; presumably it
// belonged to a now-missing <...> list on kPreconditionOptimizer32bit1State -- confirm.
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \
                float* state1, float *unorm, \
                const float beta1, const float beta2, const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const int n); \

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)

// 32-bit, one-state optimizer update kernels for the same optimizer/gtype grid as above.
// NOTE(review): `oname` is likewise unused in the visible macro body.
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
    const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

// 32-bit, two-state optimizer preconditioning kernels (ADAM and ADEMAMIX; float/half/bf16).
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \
                float* state1, float* state2, float *unorm, \
                const float beta1, const float beta2, const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const int n); \

MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)

// 32-bit, two-state optimizer update kernels, written out explicitly rather than via a macro.
// The signature carries beta3/alpha in addition to beta1/beta2.
// Judging from the two spelled-out bf16 lines, the first three entries are presumably
// ADAM (float, half, bf16) and the last three ADEMAMIX -- confirm once the stripped
// <...> lists are restored.
template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

// Static 8-bit, one-state optimizer preconditioning kernels.
// Instantiated for half and float only (no __nv_bfloat16 entries here).
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1,  \
                float *unorm,  \
                const float beta1,  \
                const float beta2,  \
                const float eps, const int step,  \
                float* __restrict__ const quantiles1,  \
                float* max1, float* new_max1,  \
                const float weight_decay, \
                const float gnorm_scale,  \
                const int n); \

MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(LION, half)
MAKE_PreconditionStatic8bit1State(LION, float)
MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)

// Static 8-bit, one-state optimizer update kernels (half and float only).
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1,  \
                const float *unorm, const float max_unorm, const float param_norm, \
                const float beta1,  \
                const float beta2,  \
                const float eps, const int step, const float lr, \
                float* __restrict__ const quantiles1,  \
                float* max1, float* new_max1,  \
                float weight_decay, \
                const float gnorm_scale,  \
                const int n); \

MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(LION, half)
MAKE_optimizerStatic8bit1State(LION, float)
MAKE_optimizerStatic8bit1State(ADAGRAD, half)
MAKE_optimizerStatic8bit1State(ADAGRAD, float)

// Static 8-bit, two-state (ADAM) preconditioning kernels, half and float.
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
                float *unorm, \
                const float beta1, const float beta2, \
                const float eps, const int step,  \
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
                const float gnorm_scale,  \
                const int n); \

MAKE_PreconditionStatic8bit2State(ADAM, half)
MAKE_PreconditionStatic8bit2State(ADAM, float)

// Static 8-bit, two-state (ADAM) update kernels, half and float.
#define MAKE_optimizerStatic8bit2State(oname, gtype) \
template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
                const float *unorm, const float max_unorm, const float param_norm, \
                const float beta1, const float beta2, \
                const float eps, const int step, const float lr, \
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
                float weight_decay, \
                const float gnorm_scale,  \
                const int n); \

MAKE_optimizerStatic8bit2State(ADAM, half)
MAKE_optimizerStatic8bit2State(ADAM, float)

// kPercentileClipping for float and half gradients.
template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n);

// Blockwise quantization kernels.
// For each dtype (half, float, __nv_bfloat16) and each format (General8bit, FP4, NF4):
//   - blocksizes 4096..64 are instantiated;
//   - num_per_thread is 4 for blocksize >= 1024 and 2 for blocksize <= 512;
//   - only the 4096/General8bit entry also gets a stochastic = 1 variant.
// NOTE(review): blocksize, num_per_thread, stochastic and data_type_name are not
// referenced in the macro body as shown. They presumably belong in a now-missing
// kQuantizeBlockwise<...> argument list; without it, every invocation for the same
// dtype names the same instantiation. Confirm against version control.
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \

MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4)
// Remaining __nv_bfloat16 blockwise-quantization instantiations (FP4 tail, then NF4).
// As elsewhere in this list, num_per_thread is 4 for blocksize >= 1024 and 2 for <= 512.
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)

// Blockwise dequantization kernels: three instantiations each for half and float output,
// then __nv_bfloat16 with <512, 64, 8, FP4 | General8bit | NF4>.
// NOTE(review): the half/float lines carry no <...> argument list, unlike the bf16 lines.
// Without one, the three half (and three float) lines cannot select different formats,
// so the lists appear to have been lost. Presumably they mirror the bf16 pattern;
// confirm against version control.
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);

// Blockwise 8-bit, two-state optimizer kernels (ADAM, ADEMAMIX).
// Each optimizer is instantiated for float, half and __nv_bfloat16,
// always with block_size 256 and num_per_thread 1.
// NOTE(review): oname, block_size and num_per_thread are not referenced in the macro
// body as shown; presumably a kOptimizerStatic8bit2StateBlockwise<...> argument list
// is missing. Confirm against version control.
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
                const float beta1, const float beta2, const float beta3, const float alpha, \
                const float eps, const int step, const float lr, \
                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                float* absmax1, float* absmax2,  \
                float weight_decay, \
                const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)

// Blockwise 8-bit, one-state optimizer kernels (MOMENTUM, RMSPROP, LION, ADAGRAD).
// Each is instantiated for float, half and __nv_bfloat16 with block_size 256, num_per_thread 1.
// NOTE(review): same unused-macro-parameter observation as the two-state macro above.
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise( \
        gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
        const float beta1, const float beta2, \
        const float eps, const int step, const float lr, \
        float* __restrict__ const quantiles1, \
        float* absmax1, \
        float weight_decay, \
        const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

// Debug helper printnonzero, instantiated for float and half buffers
// (element type deducible from A).
template __device__ void printnonzero(float *A, int num_values, const char*strval);
template __device__ void printnonzero(half *A, int num_values, const char*strval);